1use std::collections::VecDeque;
2use std::error::Error;
3use std::sync::Mutex;
4
5use agent_sdk_rs::{
6 Agent, AgentEvent, ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice,
7 ModelToolDefinition, ProviderError, ToolError, ToolOutcome, ToolSpec,
8};
9use async_trait::async_trait;
10use futures_util::StreamExt;
11use serde_json::json;
12
13#[derive(Default)]
14struct ScriptedModel {
15 responses: Mutex<VecDeque<Result<ModelCompletion, ProviderError>>>,
16}
17
18impl ScriptedModel {
19 fn new(responses: Vec<Result<ModelCompletion, ProviderError>>) -> Self {
20 Self {
21 responses: Mutex::new(VecDeque::from(responses)),
22 }
23 }
24}
25
26#[async_trait]
27impl ChatModel for ScriptedModel {
28 async fn invoke(
29 &self,
30 _messages: &[ModelMessage],
31 _tools: &[ModelToolDefinition],
32 _tool_choice: ModelToolChoice,
33 ) -> Result<ModelCompletion, ProviderError> {
34 let mut guard = self.responses.lock().expect("lock poisoned");
35 guard.pop_front().unwrap_or_else(|| {
36 Err(ProviderError::Response(
37 "scripted model exhausted responses".to_string(),
38 ))
39 })
40 }
41}
42
43fn add_tool() -> ToolSpec {
44 ToolSpec::new("add", "add two numbers")
45 .with_schema(json!({
46 "type": "object",
47 "properties": {
48 "a": {"type": "integer"},
49 "b": {"type": "integer"}
50 },
51 "required": ["a", "b"],
52 "additionalProperties": false
53 }))
54 .expect("valid schema")
55 .with_handler(|args, _deps| async move {
56 let a = args
57 .get("a")
58 .and_then(|v| v.as_i64())
59 .ok_or_else(|| ToolError::Execution("a missing".to_string()))?;
60 let b = args
61 .get("b")
62 .and_then(|v| v.as_i64())
63 .ok_or_else(|| ToolError::Execution("b missing".to_string()))?;
64 Ok(ToolOutcome::Text((a + b).to_string()))
65 })
66}
67
68fn done_tool() -> ToolSpec {
69 ToolSpec::new("done", "complete and return")
70 .with_schema(json!({
71 "type": "object",
72 "properties": {
73 "message": {"type": "string"}
74 },
75 "required": ["message"],
76 "additionalProperties": false
77 }))
78 .expect("valid schema")
79 .with_handler(|args, _deps| async move {
80 let message = args
81 .get("message")
82 .and_then(|v| v.as_str())
83 .ok_or_else(|| ToolError::Execution("message missing".to_string()))?;
84 Ok(ToolOutcome::Done(message.to_string()))
85 })
86}
87
88fn build_agent(responses: Vec<Result<ModelCompletion, ProviderError>>) -> Agent {
89 Agent::builder()
90 .model(ScriptedModel::new(responses))
91 .tool(add_tool())
92 .tool(done_tool())
93 .build()
94 .expect("agent builds")
95}
96
97#[tokio::main]
98async fn main() -> Result<(), Box<dyn Error>> {
99 let mut agent = build_agent(vec![
100 Ok(ModelCompletion {
101 text: Some("Working on it".to_string()),
102 thinking: Some("Need arithmetic".to_string()),
103 tool_calls: vec![ModelToolCall {
104 id: "call_1".to_string(),
105 name: "add".to_string(),
106 arguments: json!({"a": 2, "b": 3}),
107 }],
108 usage: None,
109 }),
110 Ok(ModelCompletion {
111 text: None,
112 thinking: None,
113 tool_calls: vec![ModelToolCall {
114 id: "call_2".to_string(),
115 name: "done".to_string(),
116 arguments: json!({"message": "2 + 3 = 5"}),
117 }],
118 usage: None,
119 }),
120 ]);
121
122 let final_response = agent.query("What is 2 + 3?").await?;
123 println!("query final: {final_response}");
124
125 let mut streaming_agent = build_agent(vec![
126 Ok(ModelCompletion {
127 text: Some("Streaming run".to_string()),
128 thinking: Some("Will call add and done".to_string()),
129 tool_calls: vec![ModelToolCall {
130 id: "call_3".to_string(),
131 name: "add".to_string(),
132 arguments: json!({"a": 10, "b": 7}),
133 }],
134 usage: None,
135 }),
136 Ok(ModelCompletion {
137 text: None,
138 thinking: None,
139 tool_calls: vec![ModelToolCall {
140 id: "call_4".to_string(),
141 name: "done".to_string(),
142 arguments: json!({"message": "10 + 7 = 17"}),
143 }],
144 usage: None,
145 }),
146 ]);
147
148 let stream = streaming_agent.query_stream("What is 10 + 7?");
149 futures_util::pin_mut!(stream);
150 while let Some(event) = stream.next().await {
151 match event? {
152 AgentEvent::MessageStart { message_id, role } => {
153 println!("message start [{message_id}] {role:?}")
154 }
155 AgentEvent::MessageComplete {
156 message_id,
157 content,
158 } => println!("message complete [{message_id}]: {content}"),
159 AgentEvent::HiddenUserMessage { content } => println!("hidden: {content}"),
160 AgentEvent::StepStart {
161 step_id,
162 title,
163 step_number,
164 } => println!("step start [{step_id}] #{step_number} {title}"),
165 AgentEvent::StepComplete {
166 step_id,
167 status,
168 duration_ms,
169 } => println!("step complete [{step_id}] {status:?} ({duration_ms} ms)"),
170 AgentEvent::Thinking { content } => println!("thinking: {content}"),
171 AgentEvent::Text { content } => println!("text: {content}"),
172 AgentEvent::ToolCall {
173 tool,
174 args_json,
175 tool_call_id,
176 } => println!("tool call [{tool_call_id}] {tool}: {args_json}"),
177 AgentEvent::ToolResult {
178 tool,
179 result_text,
180 tool_call_id,
181 is_error,
182 } => println!("tool result [{tool_call_id}] {tool}: {result_text} (error={is_error})"),
183 AgentEvent::FinalResponse { content } => println!("stream final: {content}"),
184 }
185 }
186
187 Ok(())
188}