1use forgeai_core::{
4 validate_request, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Message, Role,
5 StreamEvent, StreamResult, ToolCall, Usage,
6};
7use forgeai_tools::ToolExecutor;
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct Client {
13 adapter: Arc<dyn ChatAdapter>,
14}
15
16impl Client {
17 pub fn new(adapter: Arc<dyn ChatAdapter>) -> Self {
18 Self { adapter }
19 }
20
21 pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
22 validate_request(&request)?;
23 self.adapter.chat(request).await
24 }
25
26 pub async fn chat_stream(
27 &self,
28 request: ChatRequest,
29 ) -> Result<StreamResult<StreamEvent>, ForgeError> {
30 validate_request(&request)?;
31 self.adapter.chat_stream(request).await
32 }
33
34 pub async fn chat_with_tools(
35 &self,
36 request: ChatRequest,
37 tools: &dyn ToolExecutor,
38 options: ToolLoopOptions,
39 ) -> Result<ToolLoopResult, ForgeError> {
40 run_tool_loop(self, request, tools, options, false).await
41 }
42
43 pub async fn chat_with_tools_streaming(
44 &self,
45 request: ChatRequest,
46 tools: &dyn ToolExecutor,
47 options: ToolLoopOptions,
48 ) -> Result<ToolLoopResult, ForgeError> {
49 run_tool_loop(self, request, tools, options, true).await
50 }
51}
52
/// Settings for the tool-calling loop (`Client::chat_with_tools` and its streaming variant).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolLoopOptions {
    /// Maximum number of model calls the loop may make. The loop rejects a value of 0.
    pub max_iterations: usize,
}

impl Default for ToolLoopOptions {
    /// Allows up to 8 model calls.
    fn default() -> Self {
        const DEFAULT_MAX_ITERATIONS: usize = 8;
        Self {
            max_iterations: DEFAULT_MAX_ITERATIONS,
        }
    }
}
63
64#[derive(Debug, Clone)]
65pub struct ToolInvocation {
66 pub call_id: String,
67 pub name: String,
68 pub input: Value,
69 pub output: Value,
70}
71
72#[derive(Debug, Clone)]
73pub struct ToolLoopResult {
74 pub final_response: ChatResponse,
75 pub tool_invocations: Vec<ToolInvocation>,
76 pub iterations: usize,
77}
78
79async fn run_tool_loop(
80 client: &Client,
81 mut request: ChatRequest,
82 tools: &dyn ToolExecutor,
83 options: ToolLoopOptions,
84 use_streaming: bool,
85) -> Result<ToolLoopResult, ForgeError> {
86 validate_request(&request)?;
87 if options.max_iterations == 0 {
88 return Err(ForgeError::Validation(
89 "max_iterations must be greater than 0".to_string(),
90 ));
91 }
92
93 let mut invocations = Vec::new();
94
95 for iteration in 0..options.max_iterations {
96 let response = if use_streaming {
97 client.chat_stream_collect(request.clone()).await?
98 } else {
99 client.adapter.chat(request.clone()).await?
100 };
101
102 if response.tool_calls.is_empty() {
103 return Ok(ToolLoopResult {
104 final_response: response,
105 tool_invocations: invocations,
106 iterations: iteration + 1,
107 });
108 }
109
110 request.messages.push(Message {
111 role: Role::Assistant,
112 content: response.output_text.clone(),
113 });
114
115 for call in response.tool_calls {
116 let output = tools
117 .call(&call.name, call.arguments.clone())
118 .map_err(|e| {
119 ForgeError::Provider(format!("tool '{}' execution failed: {e}", call.name))
120 })?;
121
122 invocations.push(ToolInvocation {
123 call_id: call.id.clone(),
124 name: call.name.clone(),
125 input: call.arguments.clone(),
126 output: output.clone(),
127 });
128
129 request.messages.push(Message {
130 role: Role::Tool,
131 content: json!({
132 "tool_call_id": call.id,
133 "name": call.name,
134 "output": output
135 })
136 .to_string(),
137 });
138 }
139 }
140
141 Err(ForgeError::Provider(format!(
142 "tool loop exceeded max iterations ({})",
143 options.max_iterations
144 )))
145}
146
147impl Client {
148 async fn chat_stream_collect(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
149 let mut stream = self.chat_stream(request.clone()).await?;
150 let mut text = String::new();
151 let mut usage: Option<Usage> = None;
152 let mut tool_call_deltas: HashMap<String, Value> = HashMap::new();
153
154 use futures_util::StreamExt;
155 while let Some(item) = stream.next().await {
156 match item? {
157 StreamEvent::TextDelta { delta } => text.push_str(&delta),
158 StreamEvent::Usage { usage: u } => usage = Some(u),
159 StreamEvent::ToolCallDelta { call_id, delta } => {
160 tool_call_deltas.insert(call_id, delta);
161 }
162 StreamEvent::Done => break,
163 }
164 }
165
166 let tool_calls = tool_call_deltas
167 .into_iter()
168 .map(|(call_id, delta)| {
169 let name = delta
171 .get("name")
172 .and_then(Value::as_str)
173 .or_else(|| {
174 delta
175 .get("function")
176 .and_then(|f| f.get("name"))
177 .and_then(Value::as_str)
178 })
179 .unwrap_or("unknown_tool")
180 .to_string();
181 let arguments = delta
182 .get("arguments")
183 .cloned()
184 .or_else(|| {
185 delta
186 .get("function")
187 .and_then(|f| f.get("arguments"))
188 .cloned()
189 })
190 .unwrap_or(Value::Null);
191 ToolCall {
192 id: call_id,
193 name,
194 arguments,
195 }
196 })
197 .collect();
198
199 Ok(ChatResponse {
200 id: "stream-collected".to_string(),
201 model: request.model,
202 output_text: text,
203 tool_calls,
204 usage,
205 })
206 }
207}
208
209pub use forgeai_core;
210pub use forgeai_tools;
211
#[cfg(test)]
mod tests {
    use super::*;
    use async_stream::try_stream;
    use async_trait::async_trait;
    use forgeai_core::{AdapterInfo, CapabilityMatrix};
    use serde_json::json;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Model name used by every scripted request and response.
    const MODEL: &str = "mock-model";

    /// Scripted adapter: each `chat` / `chat_stream` call pops the next queued reply.
    struct MockAdapter {
        chat_responses: Mutex<VecDeque<ChatResponse>>,
        stream_responses: Mutex<VecDeque<Vec<StreamEvent>>>,
    }

    impl MockAdapter {
        /// Adapter that answers only non-streaming calls, in the given order.
        fn with_chat_responses(items: Vec<ChatResponse>) -> Self {
            Self::scripted(items, Vec::new())
        }

        /// Adapter that answers only streaming calls; each inner `Vec` is one complete stream.
        fn with_stream_responses(items: Vec<Vec<StreamEvent>>) -> Self {
            Self::scripted(Vec::new(), items)
        }

        fn scripted(chat: Vec<ChatResponse>, streams: Vec<Vec<StreamEvent>>) -> Self {
            Self {
                chat_responses: Mutex::new(VecDeque::from(chat)),
                stream_responses: Mutex::new(VecDeque::from(streams)),
            }
        }
    }

    /// Pops the front of `queue`, reporting `exhausted` as an internal error when it is empty.
    fn pop_scripted<T>(queue: &Mutex<VecDeque<T>>, exhausted: &str) -> Result<T, ForgeError> {
        let mut pending = queue
            .lock()
            .map_err(|_| ForgeError::Internal("lock poisoned".to_string()))?;
        let next = pending.pop_front();
        next.ok_or_else(|| ForgeError::Internal(exhausted.to_string()))
    }

    #[async_trait]
    impl ChatAdapter for MockAdapter {
        fn info(&self) -> AdapterInfo {
            let capabilities = CapabilityMatrix {
                streaming: true,
                tools: true,
                structured_output: true,
                multimodal_input: false,
                citations: false,
            };
            AdapterInfo {
                name: "mock".to_string(),
                base_url: None,
                capabilities,
            }
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ForgeError> {
            pop_scripted(&self.chat_responses, "no mock chat response remaining")
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult<StreamEvent>, ForgeError> {
            let events =
                pop_scripted(&self.stream_responses, "no mock stream response remaining")?;

            let stream = try_stream! {
                for event in events {
                    yield event;
                }
            };
            Ok(Box::pin(stream))
        }
    }

    /// Tool executor that answers every call with `{"echo": <input>}`.
    struct EchoTools;

    impl ToolExecutor for EchoTools {
        fn call(&self, _name: &str, input: Value) -> Result<Value, forgeai_tools::ToolError> {
            Ok(json!({ "echo": input }))
        }
    }

    fn base_request() -> ChatRequest {
        ChatRequest {
            model: MODEL.to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "what time is it?".to_string(),
            }],
            temperature: Some(0.1),
            max_tokens: Some(128),
            tools: vec![],
            metadata: json!({}),
        }
    }

    fn tool_call(id: &str, name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments,
        }
    }

    fn reply(id: &str, output_text: &str, tool_calls: Vec<ToolCall>) -> ChatResponse {
        ChatResponse {
            id: id.to_string(),
            model: MODEL.to_string(),
            output_text: output_text.to_string(),
            tool_calls,
            usage: None,
        }
    }

    #[tokio::test]
    async fn chat_with_tools_runs_loop_until_final_answer() {
        let first = tool_call("call-1", "time.now", json!({"timezone":"UTC"}));
        let adapter = MockAdapter::with_chat_responses(vec![
            reply("1", "", vec![first]),
            reply("2", "Current UTC time is 12:00", vec![]),
        ]);
        let client = Client::new(Arc::new(adapter));

        let result = client
            .chat_with_tools(base_request(), &EchoTools, ToolLoopOptions::default())
            .await
            .unwrap();

        assert_eq!(result.final_response.output_text, "Current UTC time is 12:00");
        assert_eq!(result.tool_invocations.len(), 1);
        assert_eq!(result.tool_invocations[0].name, "time.now");
        assert_eq!(result.iterations, 2);
    }

    #[tokio::test]
    async fn chat_with_tools_streaming_collects_events_and_executes_tools() {
        let tool_turn = vec![
            StreamEvent::ToolCallDelta {
                call_id: "call-1".to_string(),
                delta: json!({"name":"time.now","arguments":{"timezone":"UTC"}}),
            },
            StreamEvent::Done,
        ];
        let answer_turn = vec![
            StreamEvent::TextDelta {
                delta: "Current UTC time is 12:00".to_string(),
            },
            StreamEvent::Done,
        ];
        let adapter = MockAdapter::with_stream_responses(vec![tool_turn, answer_turn]);
        let client = Client::new(Arc::new(adapter));

        let result = client
            .chat_with_tools_streaming(base_request(), &EchoTools, ToolLoopOptions::default())
            .await
            .unwrap();

        assert_eq!(result.final_response.output_text, "Current UTC time is 12:00");
        assert_eq!(result.tool_invocations.len(), 1);
        assert_eq!(result.iterations, 2);
    }

    #[tokio::test]
    async fn chat_with_tools_honors_max_iterations() {
        let looping = tool_call("call-1", "loop.forever", json!({}));
        let adapter = MockAdapter::with_chat_responses(vec![reply("1", "", vec![looping])]);
        let client = Client::new(Arc::new(adapter));
        let options = ToolLoopOptions { max_iterations: 1 };

        let err = client
            .chat_with_tools(base_request(), &EchoTools, options)
            .await
            .unwrap_err();

        assert!(matches!(err, ForgeError::Provider(_)));
    }
}