// local_loop/local_loop.rs

1use std::collections::VecDeque;
2use std::error::Error;
3use std::sync::Mutex;
4
5use agent_sdk_rs::{
6    Agent, AgentEvent, ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice,
7    ModelToolDefinition, ProviderError, ToolError, ToolOutcome, ToolSpec,
8};
9use async_trait::async_trait;
10use futures_util::StreamExt;
11use serde_json::json;
12
13#[derive(Default)]
14struct ScriptedModel {
15    responses: Mutex<VecDeque<Result<ModelCompletion, ProviderError>>>,
16}
17
18impl ScriptedModel {
19    fn new(responses: Vec<Result<ModelCompletion, ProviderError>>) -> Self {
20        Self {
21            responses: Mutex::new(VecDeque::from(responses)),
22        }
23    }
24}
25
26#[async_trait]
27impl ChatModel for ScriptedModel {
28    async fn invoke(
29        &self,
30        _messages: &[ModelMessage],
31        _tools: &[ModelToolDefinition],
32        _tool_choice: ModelToolChoice,
33    ) -> Result<ModelCompletion, ProviderError> {
34        let mut guard = self.responses.lock().expect("lock poisoned");
35        guard.pop_front().unwrap_or_else(|| {
36            Err(ProviderError::Response(
37                "scripted model exhausted responses".to_string(),
38            ))
39        })
40    }
41}
42
43fn add_tool() -> ToolSpec {
44    ToolSpec::new("add", "add two numbers")
45        .with_schema(json!({
46            "type": "object",
47            "properties": {
48                "a": {"type": "integer"},
49                "b": {"type": "integer"}
50            },
51            "required": ["a", "b"],
52            "additionalProperties": false
53        }))
54        .expect("valid schema")
55        .with_handler(|args, _deps| async move {
56            let a = args
57                .get("a")
58                .and_then(|v| v.as_i64())
59                .ok_or_else(|| ToolError::Execution("a missing".to_string()))?;
60            let b = args
61                .get("b")
62                .and_then(|v| v.as_i64())
63                .ok_or_else(|| ToolError::Execution("b missing".to_string()))?;
64            Ok(ToolOutcome::Text((a + b).to_string()))
65        })
66}
67
68fn done_tool() -> ToolSpec {
69    ToolSpec::new("done", "complete and return")
70        .with_schema(json!({
71            "type": "object",
72            "properties": {
73                "message": {"type": "string"}
74            },
75            "required": ["message"],
76            "additionalProperties": false
77        }))
78        .expect("valid schema")
79        .with_handler(|args, _deps| async move {
80            let message = args
81                .get("message")
82                .and_then(|v| v.as_str())
83                .ok_or_else(|| ToolError::Execution("message missing".to_string()))?;
84            Ok(ToolOutcome::Done(message.to_string()))
85        })
86}
87
88fn build_agent(responses: Vec<Result<ModelCompletion, ProviderError>>) -> Agent {
89    Agent::builder()
90        .model(ScriptedModel::new(responses))
91        .tool(add_tool())
92        .tool(done_tool())
93        .build()
94        .expect("agent builds")
95}
96
97#[tokio::main]
98async fn main() -> Result<(), Box<dyn Error>> {
99    let mut agent = build_agent(vec![
100        Ok(ModelCompletion {
101            text: Some("Working on it".to_string()),
102            thinking: Some("Need arithmetic".to_string()),
103            tool_calls: vec![ModelToolCall {
104                id: "call_1".to_string(),
105                name: "add".to_string(),
106                arguments: json!({"a": 2, "b": 3}),
107            }],
108            usage: None,
109        }),
110        Ok(ModelCompletion {
111            text: None,
112            thinking: None,
113            tool_calls: vec![ModelToolCall {
114                id: "call_2".to_string(),
115                name: "done".to_string(),
116                arguments: json!({"message": "2 + 3 = 5"}),
117            }],
118            usage: None,
119        }),
120    ]);
121
122    let final_response = agent.query("What is 2 + 3?").await?;
123    println!("query final: {final_response}");
124
125    let mut streaming_agent = build_agent(vec![
126        Ok(ModelCompletion {
127            text: Some("Streaming run".to_string()),
128            thinking: Some("Will call add and done".to_string()),
129            tool_calls: vec![ModelToolCall {
130                id: "call_3".to_string(),
131                name: "add".to_string(),
132                arguments: json!({"a": 10, "b": 7}),
133            }],
134            usage: None,
135        }),
136        Ok(ModelCompletion {
137            text: None,
138            thinking: None,
139            tool_calls: vec![ModelToolCall {
140                id: "call_4".to_string(),
141                name: "done".to_string(),
142                arguments: json!({"message": "10 + 7 = 17"}),
143            }],
144            usage: None,
145        }),
146    ]);
147
148    let stream = streaming_agent.query_stream("What is 10 + 7?");
149    futures_util::pin_mut!(stream);
150    while let Some(event) = stream.next().await {
151        match event? {
152            AgentEvent::MessageStart { message_id, role } => {
153                println!("message start [{message_id}] {role:?}")
154            }
155            AgentEvent::MessageComplete {
156                message_id,
157                content,
158            } => println!("message complete [{message_id}]: {content}"),
159            AgentEvent::HiddenUserMessage { content } => println!("hidden: {content}"),
160            AgentEvent::StepStart {
161                step_id,
162                title,
163                step_number,
164            } => println!("step start [{step_id}] #{step_number} {title}"),
165            AgentEvent::StepComplete {
166                step_id,
167                status,
168                duration_ms,
169            } => println!("step complete [{step_id}] {status:?} ({duration_ms} ms)"),
170            AgentEvent::Thinking { content } => println!("thinking: {content}"),
171            AgentEvent::Text { content } => println!("text: {content}"),
172            AgentEvent::ToolCall {
173                tool,
174                args_json,
175                tool_call_id,
176            } => println!("tool call [{tool_call_id}] {tool}: {args_json}"),
177            AgentEvent::ToolResult {
178                tool,
179                result_text,
180                tool_call_id,
181                is_error,
182            } => println!("tool result [{tool_call_id}] {tool}: {result_text} (error={is_error})"),
183            AgentEvent::FinalResponse { content } => println!("stream final: {content}"),
184        }
185    }
186
187    Ok(())
188}