Skip to main content

forgeai/
lib.rs

1//! High-level forgeai SDK.
2
3use forgeai_core::{
4    validate_request, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Message, Role,
5    StreamEvent, StreamResult, ToolCall, Usage,
6};
7use forgeai_tools::ToolExecutor;
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct Client {
13    adapter: Arc<dyn ChatAdapter>,
14}
15
16impl Client {
17    pub fn new(adapter: Arc<dyn ChatAdapter>) -> Self {
18        Self { adapter }
19    }
20
21    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
22        validate_request(&request)?;
23        self.adapter.chat(request).await
24    }
25
26    pub async fn chat_stream(
27        &self,
28        request: ChatRequest,
29    ) -> Result<StreamResult<StreamEvent>, ForgeError> {
30        validate_request(&request)?;
31        self.adapter.chat_stream(request).await
32    }
33
34    pub async fn chat_with_tools(
35        &self,
36        request: ChatRequest,
37        tools: &dyn ToolExecutor,
38        options: ToolLoopOptions,
39    ) -> Result<ToolLoopResult, ForgeError> {
40        run_tool_loop(self, request, tools, options, false).await
41    }
42
43    pub async fn chat_with_tools_streaming(
44        &self,
45        request: ChatRequest,
46        tools: &dyn ToolExecutor,
47        options: ToolLoopOptions,
48    ) -> Result<ToolLoopResult, ForgeError> {
49        run_tool_loop(self, request, tools, options, true).await
50    }
51}
52
/// Tuning knobs for the model/tool conversation loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolLoopOptions {
    /// Upper bound on the number of model calls made by one tool loop.
    ///
    /// Must be greater than zero; the tool loop rejects `0` with a validation error.
    pub max_iterations: usize,
}

impl ToolLoopOptions {
    /// Iteration budget used by [`ToolLoopOptions::default`].
    pub const DEFAULT_MAX_ITERATIONS: usize = 8;
}

impl Default for ToolLoopOptions {
    /// Returns options allowing [`ToolLoopOptions::DEFAULT_MAX_ITERATIONS`] model calls.
    fn default() -> Self {
        Self {
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
        }
    }
}
63
64#[derive(Debug, Clone)]
65pub struct ToolInvocation {
66    pub call_id: String,
67    pub name: String,
68    pub input: Value,
69    pub output: Value,
70}
71
72#[derive(Debug, Clone)]
73pub struct ToolLoopResult {
74    pub final_response: ChatResponse,
75    pub tool_invocations: Vec<ToolInvocation>,
76    pub iterations: usize,
77}
78
79async fn run_tool_loop(
80    client: &Client,
81    mut request: ChatRequest,
82    tools: &dyn ToolExecutor,
83    options: ToolLoopOptions,
84    use_streaming: bool,
85) -> Result<ToolLoopResult, ForgeError> {
86    validate_request(&request)?;
87    if options.max_iterations == 0 {
88        return Err(ForgeError::Validation(
89            "max_iterations must be greater than 0".to_string(),
90        ));
91    }
92
93    let mut invocations = Vec::new();
94
95    for iteration in 0..options.max_iterations {
96        let response = if use_streaming {
97            client.chat_stream_collect(request.clone()).await?
98        } else {
99            client.adapter.chat(request.clone()).await?
100        };
101
102        if response.tool_calls.is_empty() {
103            return Ok(ToolLoopResult {
104                final_response: response,
105                tool_invocations: invocations,
106                iterations: iteration + 1,
107            });
108        }
109
110        request.messages.push(Message {
111            role: Role::Assistant,
112            content: response.output_text.clone(),
113        });
114
115        for call in response.tool_calls {
116            let output = tools
117                .call(&call.name, call.arguments.clone())
118                .map_err(|e| {
119                    ForgeError::Provider(format!("tool '{}' execution failed: {e}", call.name))
120                })?;
121
122            invocations.push(ToolInvocation {
123                call_id: call.id.clone(),
124                name: call.name.clone(),
125                input: call.arguments.clone(),
126                output: output.clone(),
127            });
128
129            request.messages.push(Message {
130                role: Role::Tool,
131                content: json!({
132                    "tool_call_id": call.id,
133                    "name": call.name,
134                    "output": output
135                })
136                .to_string(),
137            });
138        }
139    }
140
141    Err(ForgeError::Provider(format!(
142        "tool loop exceeded max iterations ({})",
143        options.max_iterations
144    )))
145}
146
147impl Client {
148    async fn chat_stream_collect(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
149        let mut stream = self.chat_stream(request.clone()).await?;
150        let mut text = String::new();
151        let mut usage: Option<Usage> = None;
152        let mut tool_call_deltas: HashMap<String, Value> = HashMap::new();
153
154        use futures_util::StreamExt;
155        while let Some(item) = stream.next().await {
156            match item? {
157                StreamEvent::TextDelta { delta } => text.push_str(&delta),
158                StreamEvent::Usage { usage: u } => usage = Some(u),
159                StreamEvent::ToolCallDelta { call_id, delta } => {
160                    tool_call_deltas.insert(call_id, delta);
161                }
162                StreamEvent::Done => break,
163            }
164        }
165
166        let tool_calls = tool_call_deltas
167            .into_iter()
168            .map(|(call_id, delta)| {
169                // Best-effort normalization across provider stream formats.
170                let name = delta
171                    .get("name")
172                    .and_then(Value::as_str)
173                    .or_else(|| {
174                        delta
175                            .get("function")
176                            .and_then(|f| f.get("name"))
177                            .and_then(Value::as_str)
178                    })
179                    .unwrap_or("unknown_tool")
180                    .to_string();
181                let arguments = delta
182                    .get("arguments")
183                    .cloned()
184                    .or_else(|| {
185                        delta
186                            .get("function")
187                            .and_then(|f| f.get("arguments"))
188                            .cloned()
189                    })
190                    .unwrap_or(Value::Null);
191                ToolCall {
192                    id: call_id,
193                    name,
194                    arguments,
195                }
196            })
197            .collect();
198
199        Ok(ChatResponse {
200            id: "stream-collected".to_string(),
201            model: request.model,
202            output_text: text,
203            tool_calls,
204            usage,
205        })
206    }
207}
208
209pub use forgeai_core;
210pub use forgeai_tools;
211
#[cfg(test)]
mod tests {
    use super::*;
    use async_stream::try_stream;
    use async_trait::async_trait;
    use forgeai_core::{AdapterInfo, CapabilityMatrix};
    use serde_json::json;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Scripted adapter: every call pops the next canned reply and fails once the script is
    /// exhausted.
    struct MockAdapter {
        chat_responses: Mutex<VecDeque<ChatResponse>>,
        stream_responses: Mutex<VecDeque<Vec<StreamEvent>>>,
    }

    impl MockAdapter {
        /// Adapter that only answers non-streaming calls.
        fn with_chat_responses(items: Vec<ChatResponse>) -> Self {
            Self {
                chat_responses: Mutex::new(VecDeque::from(items)),
                stream_responses: Mutex::new(VecDeque::new()),
            }
        }

        /// Adapter that only answers streaming calls.
        fn with_stream_responses(items: Vec<Vec<StreamEvent>>) -> Self {
            Self {
                chat_responses: Mutex::new(VecDeque::new()),
                stream_responses: Mutex::new(VecDeque::from(items)),
            }
        }
    }

    #[async_trait]
    impl ChatAdapter for MockAdapter {
        fn info(&self) -> AdapterInfo {
            AdapterInfo {
                name: "mock".to_string(),
                base_url: None,
                capabilities: CapabilityMatrix {
                    streaming: true,
                    tools: true,
                    structured_output: true,
                    multimodal_input: false,
                    citations: false,
                },
            }
        }

        async fn chat(&self, _request: ChatRequest) -> Result<ChatResponse, ForgeError> {
            self.chat_responses
                .lock()
                .map_err(|_| ForgeError::Internal("lock poisoned".to_string()))?
                .pop_front()
                .ok_or_else(|| ForgeError::Internal("no mock chat response remaining".to_string()))
        }

        async fn chat_stream(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult<StreamEvent>, ForgeError> {
            let events = self
                .stream_responses
                .lock()
                .map_err(|_| ForgeError::Internal("lock poisoned".to_string()))?
                .pop_front()
                .ok_or_else(|| {
                    ForgeError::Internal("no mock stream response remaining".to_string())
                })?;

            let stream = try_stream! {
                for event in events {
                    yield event;
                }
            };
            Ok(Box::pin(stream))
        }
    }

    /// Tool executor that wraps whatever input it receives in `{"echo": ...}`.
    struct EchoTools;

    impl ToolExecutor for EchoTools {
        fn call(&self, _name: &str, input: Value) -> Result<Value, forgeai_tools::ToolError> {
            Ok(json!({ "echo": input }))
        }
    }

    /// Minimal request shared by all tests.
    fn base_request() -> ChatRequest {
        ChatRequest {
            model: "mock-model".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "what time is it?".to_string(),
            }],
            temperature: Some(0.1),
            max_tokens: Some(128),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// Canned model reply that requests exactly one tool call and carries no text.
    fn tool_call_response(id: &str, tool_name: &str, arguments: Value) -> ChatResponse {
        ChatResponse {
            id: id.to_string(),
            model: "mock-model".to_string(),
            output_text: String::new(),
            tool_calls: vec![ToolCall {
                id: "call-1".to_string(),
                name: tool_name.to_string(),
                arguments,
            }],
            usage: None,
        }
    }

    /// Canned model reply that carries only text.
    fn text_response(id: &str, text: &str) -> ChatResponse {
        ChatResponse {
            id: id.to_string(),
            model: "mock-model".to_string(),
            output_text: text.to_string(),
            tool_calls: vec![],
            usage: None,
        }
    }

    #[tokio::test]
    async fn chat_with_tools_runs_loop_until_final_answer() {
        let adapter = MockAdapter::with_chat_responses(vec![
            tool_call_response("1", "time.now", json!({"timezone":"UTC"})),
            text_response("2", "Current UTC time is 12:00"),
        ]);

        let client = Client::new(Arc::new(adapter));
        let result = client
            .chat_with_tools(base_request(), &EchoTools, ToolLoopOptions::default())
            .await
            .unwrap();

        assert_eq!(
            result.final_response.output_text,
            "Current UTC time is 12:00"
        );
        assert_eq!(result.iterations, 2);
        assert_eq!(result.tool_invocations.len(), 1);

        let invocation = &result.tool_invocations[0];
        assert_eq!(invocation.call_id, "call-1");
        assert_eq!(invocation.name, "time.now");
        assert_eq!(invocation.input, json!({"timezone":"UTC"}));
        assert_eq!(invocation.output, json!({"echo": {"timezone":"UTC"}}));
    }

    #[tokio::test]
    async fn chat_with_tools_streaming_collects_events_and_executes_tools() {
        let adapter = MockAdapter::with_stream_responses(vec![
            vec![
                StreamEvent::ToolCallDelta {
                    call_id: "call-1".to_string(),
                    delta: json!({"name":"time.now","arguments":{"timezone":"UTC"}}),
                },
                StreamEvent::Done,
            ],
            vec![
                StreamEvent::TextDelta {
                    delta: "Current UTC time is 12:00".to_string(),
                },
                StreamEvent::Done,
            ],
        ]);

        let client = Client::new(Arc::new(adapter));
        let result = client
            .chat_with_tools_streaming(base_request(), &EchoTools, ToolLoopOptions::default())
            .await
            .unwrap();

        assert_eq!(
            result.final_response.output_text,
            "Current UTC time is 12:00"
        );
        assert_eq!(result.iterations, 2);
        assert_eq!(result.tool_invocations.len(), 1);

        // The collected delta must surface the tool name and arguments it carried.
        let invocation = &result.tool_invocations[0];
        assert_eq!(invocation.call_id, "call-1");
        assert_eq!(invocation.name, "time.now");
        assert_eq!(invocation.input, json!({"timezone":"UTC"}));
    }

    #[tokio::test]
    async fn chat_with_tools_honors_max_iterations() {
        let adapter = MockAdapter::with_chat_responses(vec![tool_call_response(
            "1",
            "loop.forever",
            json!({}),
        )]);

        let client = Client::new(Arc::new(adapter));
        let err = client
            .chat_with_tools(
                base_request(),
                &EchoTools,
                ToolLoopOptions { max_iterations: 1 },
            )
            .await
            .unwrap_err();

        // Checking the message tells the budget error apart from a tool failure, which is
        // also reported as `ForgeError::Provider`.
        assert!(matches!(
            &err,
            ForgeError::Provider(msg) if msg.contains("max iterations")
        ));
    }

    #[tokio::test]
    async fn chat_with_tools_rejects_zero_max_iterations() {
        // No canned replies: the loop must fail before it ever reaches the adapter.
        let client = Client::new(Arc::new(MockAdapter::with_chat_responses(vec![])));
        let err = client
            .chat_with_tools(
                base_request(),
                &EchoTools,
                ToolLoopOptions { max_iterations: 0 },
            )
            .await
            .unwrap_err();

        assert!(matches!(err, ForgeError::Validation(_)));
    }
}