sayr_engine/
llm.rs

1//! Language model implementations and abstractions.
2#![allow(dead_code)]
3
4use std::collections::{HashMap, VecDeque};
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12
13use crate::config::ModelConfig;
14use crate::error::{AgnoError, Result};
15use crate::message::{Message, Role, ToolCall};
16use crate::tool::ToolDescription;
17
18/// Result of a chat completion request.
19#[derive(Debug, Clone, PartialEq, Serialize)]
20pub struct ModelCompletion {
21    pub content: Option<String>,
22    pub tool_calls: Vec<ToolCall>,
23}
24
25/// Minimal abstraction around a chat completion provider.
26#[async_trait]
27pub trait LanguageModel: Send + Sync {
28    async fn complete_chat(
29        &self,
30        messages: &[Message],
31        tools: &[ToolDescription],
32        stream: bool,
33    ) -> Result<ModelCompletion>;
34}
35
36fn coalesce_error(status: reqwest::StatusCode, body: &str, provider: &str) -> AgnoError {
37    if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
38        return AgnoError::LanguageModel(format!("{provider} rate limit exceeded: {body}"));
39    }
40    AgnoError::LanguageModel(format!("{provider} request failed with {}: {body}", status))
41}
42
43fn serialize_tool_arguments(args: &Value) -> String {
44    serde_json::to_string(args).unwrap_or_else(|_| args.to_string())
45}
46
47#[derive(Clone)]
48pub struct OpenAIClient {
49    http: reqwest::Client,
50    model: String,
51    api_key: String,
52    base_url: String,
53    organization: Option<String>,
54}
55
56impl OpenAIClient {
57    pub fn new(api_key: impl Into<String>) -> Self {
58        Self {
59            http: reqwest::Client::new(),
60            model: "gpt-4-turbo-preview".to_string(),
61            api_key: api_key.into(),
62            base_url: "https://api.openai.com/v1".to_string(),
63            organization: None,
64        }
65    }
66
67    pub fn from_env() -> Result<Self> {
68        let api_key = std::env::var("OPENAI_API_KEY")
69            .map_err(|_| AgnoError::LanguageModel("OPENAI_API_KEY not found".into()))?;
70        Ok(Self::new(api_key))
71    }
72
73    pub fn with_model(mut self, model: impl Into<String>) -> Self {
74        self.model = model.into();
75        self
76    }
77
78    pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
79        let api_key = cfg
80            .openai
81            .api_key
82            .clone()
83            .or_else(|| cfg.api_key.clone())
84            .ok_or_else(|| {
85                AgnoError::LanguageModel("missing OpenAI API key in model config".into())
86            })?;
87        let base_url = cfg
88            .openai
89            .endpoint
90            .clone()
91            .or_else(|| cfg.base_url.clone())
92            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
93        Ok(Self {
94            http: reqwest::Client::builder()
95                .timeout(Duration::from_secs(60))
96                .build()
97                .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
98            model: cfg.model.clone(),
99            api_key,
100            base_url,
101            organization: cfg
102                .openai
103                .organization
104                .clone()
105                .or_else(|| cfg.organization.clone()),
106        })
107    }
108
109    fn to_openai_messages(&self, messages: &[Message]) -> Vec<OpenAiMessage> {
110        let mut built = Vec::new();
111        for message in messages {
112            let role = match message.role {
113                Role::System => "system",
114                Role::User => "user",
115                Role::Assistant => "assistant",
116                Role::Tool => "tool",
117            }
118            .to_string();
119
120            let mut tool_calls = None;
121            if let Some(call) = &message.tool_call {
122                tool_calls = Some(vec![OpenAiToolCall {
123                    id: call.id.clone(),
124                    r#type: "function".to_string(),
125                    function: OpenAiFunctionCall {
126                        name: call.name.clone(),
127                        arguments: serialize_tool_arguments(&call.arguments),
128                    },
129                }]);
130            }
131
132            let content = if message.role == Role::Tool {
133                message
134                    .tool_result
135                    .as_ref()
136                    .map(|result| serialize_tool_arguments(&result.output))
137                    .or_else(|| Some(message.content.clone()))
138            } else {
139                Some(message.content.clone())
140            };
141
142            let tool_call_id = message
143                .tool_result
144                .as_ref()
145                .and_then(|result| result.tool_call_id.clone());
146
147            built.push(OpenAiMessage {
148                role,
149                content,
150                tool_call_id,
151                tool_calls,
152            });
153        }
154        built
155    }
156
157    fn to_openai_tools(&self, tools: &[ToolDescription]) -> Option<Vec<OpenAiTool>> {
158        if tools.is_empty() {
159            return None;
160        }
161
162        Some(
163            tools
164                .iter()
165                .map(|tool| OpenAiTool {
166                    r#type: "function".to_string(),
167                    function: OpenAiFunction {
168                        name: tool.name.clone(),
169                        description: Some(tool.description.clone()),
170                        parameters: tool.parameters.clone(),
171                    },
172                })
173                .collect(),
174        )
175    }
176}
177
178#[async_trait]
179impl LanguageModel for OpenAIClient {
180    async fn complete_chat(
181        &self,
182        messages: &[Message],
183        tools: &[ToolDescription],
184        stream: bool,
185    ) -> Result<ModelCompletion> {
186        let payload = json!({
187            "model": self.model,
188            "messages": self.to_openai_messages(messages),
189            "tools": self.to_openai_tools(tools),
190            "tool_choice": if tools.is_empty() { Value::Null } else { Value::String("auto".to_string()) },
191            "stream": stream,
192        });
193
194        let mut builder = self
195            .http
196            .post(format!("{}/chat/completions", self.base_url))
197            .header(
198                reqwest::header::AUTHORIZATION,
199                format!("Bearer {}", self.api_key),
200            );
201        if let Some(org) = &self.organization {
202            builder = builder.header("OpenAI-Organization", org);
203        }
204        let resp = builder
205            .json(&payload)
206            .send()
207            .await
208            .map_err(|err| AgnoError::LanguageModel(format!("OpenAI request error: {err}")))?;
209
210        if !resp.status().is_success() {
211            let status = resp.status();
212            let body = resp.text().await.unwrap_or_default();
213            return Err(coalesce_error(status, &body, "openai"));
214        }
215
216        if stream {
217            let mut content = String::new();
218            let mut tool_calls: HashMap<String, OpenAiToolCallState> = HashMap::new();
219            let mut stream = resp.bytes_stream();
220            while let Some(chunk) = stream.next().await {
221                let chunk = chunk.map_err(|err| {
222                    AgnoError::LanguageModel(format!("OpenAI stream error: {err}"))
223                })?;
224                let text = String::from_utf8_lossy(&chunk);
225                for line in text.lines() {
226                    if !line.starts_with("data: ") {
227                        continue;
228                    }
229                    let data = line.trim_start_matches("data: ").trim();
230                    if data == "[DONE]" {
231                        continue;
232                    }
233                    let parsed: OpenAiStreamChunk = serde_json::from_str(data).map_err(|err| {
234                        AgnoError::LanguageModel(format!(
235                            "OpenAI stream parse error `{data}`: {err}"
236                        ))
237                    })?;
238
239                    for choice in parsed.choices {
240                        if let Some(delta_content) = choice.delta.content {
241                            content.push_str(&delta_content);
242                        }
243                        if let Some(calls) = choice.delta.tool_calls {
244                            for delta_call in calls {
245                                let id = delta_call
246                                    .id
247                                    .clone()
248                                    .unwrap_or_else(|| format!("call_{}", tool_calls.len()));
249                                let state = tool_calls.entry(id.clone()).or_default();
250                                if let Some(function) = delta_call.function {
251                                    if let Some(name) = function.name {
252                                        state.name = Some(name);
253                                    }
254                                    if let Some(args) = function.arguments {
255                                        state.arguments.push_str(&args);
256                                    }
257                                }
258                                state.id = Some(id);
259                            }
260                        }
261                    }
262                }
263            }
264
265            let calls: Vec<ToolCall> = tool_calls
266                .into_values()
267                .filter_map(|state| {
268                    let name = state.name?;
269                    let args = serde_json::from_str(&state.arguments)
270                        .unwrap_or_else(|_| Value::String(state.arguments.clone()));
271                    Some(ToolCall {
272                        id: state.id,
273                        name,
274                        arguments: args,
275                    })
276                })
277                .collect();
278
279            return Ok(ModelCompletion {
280                content: if content.is_empty() {
281                    None
282                } else {
283                    Some(content)
284                },
285                tool_calls: calls,
286            });
287        }
288
289        let body: OpenAiResponse = resp.json().await.map_err(|err| {
290            AgnoError::LanguageModel(format!("OpenAI response parse error: {err}"))
291        })?;
292
293        let first = body
294            .choices
295            .into_iter()
296            .next()
297            .ok_or_else(|| AgnoError::LanguageModel("OpenAI returned no choices".into()))?;
298
299        let mut tool_calls = Vec::new();
300        if let Some(calls) = first.message.tool_calls {
301            for call in calls {
302                let args = serde_json::from_str(&call.function.arguments)
303                    .unwrap_or_else(|_| Value::String(call.function.arguments.clone()));
304                tool_calls.push(ToolCall {
305                    id: call.id,
306                    name: call.function.name,
307                    arguments: args,
308                });
309            }
310        }
311
312        Ok(ModelCompletion {
313            content: first.message.content,
314            tool_calls,
315        })
316    }
317}
318
319#[derive(Clone)]
320pub struct AnthropicClient {
321    http: reqwest::Client,
322    model: String,
323    api_key: String,
324    endpoint: String,
325}
326
327impl AnthropicClient {
328    pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
329        let api_key = cfg
330            .anthropic
331            .api_key
332            .clone()
333            .or_else(|| cfg.api_key.clone())
334            .ok_or_else(|| {
335                AgnoError::LanguageModel("missing Anthropic API key in model config".into())
336            })?;
337        let endpoint = cfg
338            .anthropic
339            .endpoint
340            .clone()
341            .unwrap_or_else(|| "https://api.anthropic.com/v1/messages".to_string());
342        Ok(Self {
343            http: reqwest::Client::builder()
344                .timeout(Duration::from_secs(60))
345                .build()
346                .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
347            model: cfg.model.clone(),
348            api_key,
349            endpoint,
350        })
351    }
352
353    fn to_messages(&self, messages: &[Message]) -> Vec<AnthropicMessage> {
354        messages
355            .iter()
356            .filter_map(|message| match message.role {
357                Role::System => None,
358                Role::User | Role::Assistant | Role::Tool => Some(AnthropicMessage {
359                    role: match message.role {
360                        Role::User => "user",
361                        Role::Assistant | Role::Tool => "assistant",
362                        Role::System => unreachable!(),
363                    }
364                    .to_string(),
365                    content: vec![AnthropicContentBlock {
366                        r#type: "text".to_string(),
367                        text: Some(message.content.clone()),
368                        name: None,
369                        input_schema: None,
370                    }],
371                }),
372            })
373            .collect()
374    }
375
376    fn to_tools(&self, tools: &[ToolDescription]) -> Option<Vec<AnthropicTool>> {
377        if tools.is_empty() {
378            return None;
379        }
380        Some(
381            tools
382                .iter()
383                .map(|tool| AnthropicTool {
384                    name: tool.name.clone(),
385                    description: tool.description.clone(),
386                    input_schema: tool
387                        .parameters
388                        .clone()
389                        .unwrap_or_else(|| json!({"type":"object"})),
390                })
391                .collect(),
392        )
393    }
394}
395
396#[async_trait]
397impl LanguageModel for AnthropicClient {
398    async fn complete_chat(
399        &self,
400        messages: &[Message],
401        tools: &[ToolDescription],
402        stream: bool,
403    ) -> Result<ModelCompletion> {
404        let system = messages
405            .iter()
406            .find(|m| m.role == Role::System)
407            .map(|m| m.content.clone());
408        let payload = json!({
409            "model": self.model,
410            "system": system,
411            "messages": self.to_messages(messages),
412            "tools": self.to_tools(tools),
413            "stream": stream,
414        });
415
416        let resp = self
417            .http
418            .post(&self.endpoint)
419            .header("x-api-key", &self.api_key)
420            .header("anthropic-version", "2023-06-01")
421            .json(&payload)
422            .send()
423            .await
424            .map_err(|err| AgnoError::LanguageModel(format!("Anthropic request error: {err}")))?;
425
426        if !resp.status().is_success() {
427            let status = resp.status();
428            let body = resp.text().await.unwrap_or_default();
429            return Err(coalesce_error(status, &body, "anthropic"));
430        }
431
432        if stream {
433            let mut content = String::new();
434            let mut stream = resp.bytes_stream();
435            while let Some(chunk) = stream.next().await {
436                let chunk = chunk.map_err(|err| {
437                    AgnoError::LanguageModel(format!("Anthropic stream error: {err}"))
438                })?;
439                let text = String::from_utf8_lossy(&chunk);
440                for line in text.lines() {
441                    if !line.starts_with("data: ") {
442                        continue;
443                    }
444                    let data = line.trim_start_matches("data: ").trim();
445                    if data == "[DONE]" || data.is_empty() {
446                        continue;
447                    }
448                    let parsed: AnthropicStreamChunk =
449                        serde_json::from_str(data).map_err(|err| {
450                            AgnoError::LanguageModel(format!(
451                                "Anthropic stream parse error `{data}`: {err}"
452                            ))
453                        })?;
454                    if let Some(text) = parsed.delta.text {
455                        content.push_str(&text);
456                    }
457                }
458            }
459
460            return Ok(ModelCompletion {
461                content: if content.is_empty() {
462                    None
463                } else {
464                    Some(content)
465                },
466                tool_calls: Vec::new(),
467            });
468        }
469
470        let parsed: AnthropicResponse = resp.json().await.map_err(|err| {
471            AgnoError::LanguageModel(format!("Anthropic response parse error: {err}"))
472        })?;
473
474        let content = parsed
475            .content
476            .iter()
477            .filter_map(|block| block.text.clone())
478            .collect::<Vec<String>>()
479            .join("");
480
481        Ok(ModelCompletion {
482            content: if content.is_empty() {
483                None
484            } else {
485                Some(content)
486            },
487            tool_calls: Vec::new(),
488        })
489    }
490}
491
492#[derive(Clone)]
493pub struct GeminiClient {
494    http: reqwest::Client,
495    model: String,
496    api_key: String,
497    endpoint: String,
498}
499
500impl GeminiClient {
501    pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
502        let api_key = cfg
503            .gemini
504            .api_key
505            .clone()
506            .or_else(|| cfg.api_key.clone())
507            .ok_or_else(|| {
508                AgnoError::LanguageModel("missing Gemini API key in model config".into())
509            })?;
510        let endpoint = cfg
511            .gemini
512            .endpoint
513            .clone()
514            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string());
515        Ok(Self {
516            http: reqwest::Client::builder()
517                .timeout(Duration::from_secs(60))
518                .build()
519                .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
520            model: cfg.model.clone(),
521            api_key,
522            endpoint,
523        })
524    }
525
526    fn to_contents(&self, messages: &[Message]) -> Vec<GeminiMessage> {
527        messages
528            .iter()
529            .filter_map(|message| {
530                let role = match message.role {
531                    Role::User => "user",
532                    Role::Assistant => "model",
533                    Role::System => "system",
534                    Role::Tool => "user",
535                };
536                Some(GeminiMessage {
537                    role: role.to_string(),
538                    parts: vec![GeminiPart {
539                        text: message.content.clone(),
540                    }],
541                })
542            })
543            .collect()
544    }
545}
546
547#[async_trait]
548impl LanguageModel for GeminiClient {
549    async fn complete_chat(
550        &self,
551        messages: &[Message],
552        _tools: &[ToolDescription],
553        _stream: bool,
554    ) -> Result<ModelCompletion> {
555        let payload = json!({
556            "contents": self.to_contents(messages),
557        });
558        let resp = self
559            .http
560            .post(format!(
561                "{}/models/{}:generateContent?key={}",
562                self.endpoint, self.model, self.api_key
563            ))
564            .json(&payload)
565            .send()
566            .await
567            .map_err(|err| AgnoError::LanguageModel(format!("Gemini request error: {err}")))?;
568
569        if !resp.status().is_success() {
570            let status = resp.status();
571            let body = resp.text().await.unwrap_or_default();
572            return Err(coalesce_error(status, &body, "gemini"));
573        }
574
575        let parsed: GeminiResponse = resp.json().await.map_err(|err| {
576            AgnoError::LanguageModel(format!("Gemini response parse error: {err}"))
577        })?;
578
579        let content = parsed
580            .candidates
581            .get(0)
582            .and_then(|cand| cand.content.parts.get(0))
583            .map(|part| part.text.clone())
584            .unwrap_or_default();
585
586        Ok(ModelCompletion {
587            content: if content.is_empty() {
588                None
589            } else {
590                Some(content)
591            },
592            tool_calls: Vec::new(),
593        })
594    }
595}
596
597#[derive(Clone)]
598pub struct CohereClient {
599    http: reqwest::Client,
600    model: String,
601    api_key: String,
602    endpoint: String,
603}
604
605impl CohereClient {
606    pub fn new(api_key: impl Into<String>) -> Self {
607        Self {
608            http: reqwest::Client::builder()
609                .timeout(Duration::from_secs(60))
610                .build()
611                .expect("failed to build http client"),
612            model: "command-a-03-2025".to_string(),
613            api_key: api_key.into(),
614            endpoint: "https://api.cohere.ai/v2/chat".to_string(),
615        }
616    }
617
618    pub fn with_model(mut self, model: impl Into<String>) -> Self {
619        self.model = model.into();
620        self
621    }
622
623    pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
624        let api_key = cfg
625            .cohere
626            .api_key
627            .clone()
628            .or_else(|| cfg.api_key.clone())
629            .ok_or_else(|| {
630                AgnoError::LanguageModel("missing Cohere API key in model config".into())
631            })?;
632        let endpoint = cfg
633            .cohere
634            .endpoint
635            .clone()
636            .unwrap_or_else(|| "https://api.cohere.ai/v2/chat".to_string());
637        Ok(Self {
638            http: reqwest::Client::builder()
639                .timeout(Duration::from_secs(60))
640                .build()
641                .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
642            model: cfg.model.clone(),
643            api_key,
644            endpoint,
645        })
646    }
647
648    fn to_messages(&self, messages: &[Message]) -> Vec<CohereMessage> {
649        messages
650            .iter()
651            .map(|message| {
652                let role = match message.role {
653                    Role::System => "system",
654                    Role::User => "user",
655                    Role::Assistant => "assistant",
656                    Role::Tool => "tool",
657                };
658                CohereMessage {
659                    role: role.to_string(),
660                    content: message.content.clone(),
661                }
662            })
663            .collect()
664    }
665
666    fn to_tools(&self, tools: &[ToolDescription]) -> Option<Vec<CohereTool>> {
667        if tools.is_empty() {
668            return None;
669        }
670        Some(
671            tools
672                .iter()
673                .map(|tool| CohereTool {
674                    r#type: "function".to_string(),
675                    function: CohereFunction {
676                        name: tool.name.clone(),
677                        description: Some(tool.description.clone()),
678                        parameters: tool.parameters.clone(),
679                    },
680                })
681                .collect(),
682        )
683    }
684}
685
686#[async_trait]
687impl LanguageModel for CohereClient {
688    async fn complete_chat(
689        &self,
690        messages: &[Message],
691        tools: &[ToolDescription],
692        stream: bool,
693    ) -> Result<ModelCompletion> {
694        let payload = json!({
695            "model": self.model,
696            "messages": self.to_messages(messages),
697            "tools": self.to_tools(tools),
698            "stream": stream,
699        });
700
701        let resp = self
702            .http
703            .post(&self.endpoint)
704            .header("Authorization", format!("Bearer {}", self.api_key))
705            .header("Content-Type", "application/json")
706            .json(&payload)
707            .send()
708            .await
709            .map_err(|err| AgnoError::LanguageModel(format!("Cohere request error: {err}")))?;
710
711        if !resp.status().is_success() {
712            let status = resp.status();
713            let body = resp.text().await.unwrap_or_default();
714            return Err(coalesce_error(status, &body, "cohere"));
715        }
716
717        if stream {
718            let mut content = String::new();
719            let tool_calls_map: HashMap<String, OpenAiToolCallState> = HashMap::new();
720            let mut stream = resp.bytes_stream();
721            while let Some(chunk) = stream.next().await {
722                let chunk = chunk.map_err(|err| {
723                    AgnoError::LanguageModel(format!("Cohere stream error: {err}"))
724                })?;
725                let text = String::from_utf8_lossy(&chunk);
726                for line in text.lines() {
727                    if !line.starts_with("data: ") {
728                        continue;
729                    }
730                    let data = line.trim_start_matches("data: ").trim();
731                    if data == "[DONE]" || data.is_empty() {
732                        continue;
733                    }
734                    if let Ok(parsed) = serde_json::from_str::<CohereStreamChunk>(data) {
735                        if let Some(delta) = parsed.delta {
736                            if let Some(msg) = delta.message {
737                                if let Some(c) = msg.content {
738                                    if let Some(text_content) = c.get("text") {
739                                        if let Some(t) = text_content.as_str() {
740                                            content.push_str(t);
741                                        }
742                                    }
743                                }
744                            }
745                        }
746                    }
747                }
748            }
749
750            let calls: Vec<ToolCall> = tool_calls_map
751                .into_values()
752                .filter_map(|state| {
753                    let name = state.name?;
754                    let args = serde_json::from_str(&state.arguments)
755                        .unwrap_or_else(|_| Value::String(state.arguments.clone()));
756                    Some(ToolCall {
757                        id: state.id,
758                        name,
759                        arguments: args,
760                    })
761                })
762                .collect();
763
764            return Ok(ModelCompletion {
765                content: if content.is_empty() {
766                    None
767                } else {
768                    Some(content)
769                },
770                tool_calls: calls,
771            });
772        }
773
774        let body: CohereResponse = resp.json().await.map_err(|err| {
775            AgnoError::LanguageModel(format!("Cohere response parse error: {err}"))
776        })?;
777
778        let content = body.message.and_then(|m| {
779            m.content.and_then(|c| {
780                if let Some(arr) = c.as_array() {
781                    let mut text = String::new();
782                    for item in arr {
783                        // Check if type is text (or implicit if just text field exists)
784                        if let Some(t) = item.get("text").and_then(|v| v.as_str()) {
785                            text.push_str(t);
786                        }
787                    }
788                    if text.is_empty() { None } else { Some(text) }
789                } else {
790                     c.get("text").and_then(|v| v.as_str().map(|s| s.to_string()))
791                }
792            })
793        });
794
795        let mut tool_calls = Vec::new();
796        if let Some(calls) = body.tool_calls {
797            for call in calls {
798                let args = serde_json::from_str(&call.function.arguments)
799                    .unwrap_or_else(|_| Value::String(call.function.arguments.clone()));
800                tool_calls.push(ToolCall {
801                    id: call.id,
802                    name: call.function.name,
803                    arguments: args,
804                });
805            }
806        }
807
808        Ok(ModelCompletion {
809            content,
810            tool_calls,
811        })
812    }
813}
814
815// ─────────────────────────────────────────────────────────────────────────────
816// Groq Client (OpenAI-compatible API)
817// ─────────────────────────────────────────────────────────────────────────────
818
819/// Groq client - uses OpenAI-compatible API with Groq's endpoint.
820/// Default model: llama-3.3-70b-versatile
821#[derive(Clone)]
822pub struct GroqClient {
823    http: reqwest::Client,
824    model: String,
825    api_key: String,
826    base_url: String,
827}
828
829impl GroqClient {
830    pub fn new(api_key: impl Into<String>) -> Self {
831        Self {
832            http: reqwest::Client::builder()
833                .timeout(Duration::from_secs(120))
834                .build()
835                .expect("failed to build http client"),
836            model: "llama-3.3-70b-versatile".to_string(),
837            api_key: api_key.into(),
838            base_url: "https://api.groq.com/openai/v1".to_string(),
839        }
840    }
841
842    pub fn with_model(mut self, model: impl Into<String>) -> Self {
843        self.model = model.into();
844        self
845    }
846
847    pub fn from_env() -> Result<Self> {
848        let api_key = std::env::var("GROQ_API_KEY")
849            .map_err(|_| AgnoError::LanguageModel("GROQ_API_KEY not set".into()))?;
850        Ok(Self::new(api_key))
851    }
852}
853
854#[async_trait]
855impl LanguageModel for GroqClient {
856    async fn complete_chat(
857        &self,
858        messages: &[Message],
859        tools: &[ToolDescription],
860        stream: bool,
861    ) -> Result<ModelCompletion> {
862        // Convert messages to OpenAI format
863        let oai_messages: Vec<Value> = messages
864            .iter()
865            .map(|m| {
866                let role = match m.role {
867                    Role::System => "system",
868                    Role::User => "user",
869                    Role::Assistant => "assistant",
870                    Role::Tool => "tool",
871                };
872                let mut msg = json!({
873                    "role": role,
874                    "content": m.content.clone()
875                });
876                if let Some(ref result) = m.tool_result {
877                    if let Some(ref call_id) = result.tool_call_id {
878                        msg["tool_call_id"] = json!(call_id);
879                    }
880                }
881                msg
882            })
883            .collect();
884
885        let mut body = json!({
886            "model": self.model,
887            "messages": oai_messages,
888            "stream": stream
889        });
890
891        if !tools.is_empty() {
892            let oai_tools: Vec<Value> = tools
893                .iter()
894                .map(|t| {
895                    json!({
896                        "type": "function",
897                        "function": {
898                            "name": t.name,
899                            "description": t.description,
900                            "parameters": t.parameters
901                        }
902                    })
903                })
904                .collect();
905            body["tools"] = json!(oai_tools);
906        }
907
908        let resp = self
909            .http
910            .post(format!("{}/chat/completions", self.base_url))
911            .header("Authorization", format!("Bearer {}", self.api_key))
912            .header("Content-Type", "application/json")
913            .json(&body)
914            .send()
915            .await
916            .map_err(|e| AgnoError::LanguageModel(format!("Groq request failed: {e}")))?;
917
918        let status = resp.status();
919        if !status.is_success() {
920            let body = resp.text().await.unwrap_or_default();
921            return Err(coalesce_error(status, &body, "Groq"));
922        }
923
924        let json: Value = resp
925            .json()
926            .await
927            .map_err(|e| AgnoError::LanguageModel(format!("Groq parse error: {e}")))?;
928
929        let choice = &json["choices"][0]["message"];
930        let content = choice["content"].as_str().map(String::from);
931
932        let mut tool_calls = Vec::new();
933        if let Some(calls) = choice["tool_calls"].as_array() {
934            for call in calls {
935                let name = call["function"]["name"].as_str().unwrap_or("").to_string();
936                let args_str = call["function"]["arguments"].as_str().unwrap_or("{}");
937                let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
938                tool_calls.push(ToolCall {
939                    id: call["id"].as_str().map(String::from),
940                    name,
941                    arguments: args,
942                });
943            }
944        }
945
946        Ok(ModelCompletion { content, tool_calls })
947    }
948}
949
950// ─────────────────────────────────────────────────────────────────────────────
951// Ollama Client (Local LLM)
952// ─────────────────────────────────────────────────────────────────────────────
953
954/// Ollama client for local LLM inference.
955/// Default model: llama3.1
956#[derive(Clone)]
957pub struct OllamaClient {
958    http: reqwest::Client,
959    model: String,
960    base_url: String,
961}
962
963impl OllamaClient {
964    pub fn new() -> Self {
965        Self {
966            http: reqwest::Client::builder()
967                .timeout(Duration::from_secs(300)) // Local models can be slow
968                .build()
969                .expect("failed to build http client"),
970            model: "llama3.1".to_string(),
971            base_url: "http://localhost:11434".to_string(),
972        }
973    }
974
975    pub fn with_model(mut self, model: impl Into<String>) -> Self {
976        self.model = model.into();
977        self
978    }
979
980    pub fn with_host(mut self, host: impl Into<String>) -> Self {
981        self.base_url = host.into();
982        self
983    }
984
985    pub fn from_env() -> Self {
986        let mut client = Self::new();
987        if let Ok(host) = std::env::var("OLLAMA_HOST") {
988            client.base_url = host;
989        }
990        if let Ok(model) = std::env::var("OLLAMA_MODEL") {
991            client.model = model;
992        }
993        client
994    }
995}
996
997impl Default for OllamaClient {
998    fn default() -> Self {
999        Self::new()
1000    }
1001}
1002
1003#[async_trait]
1004impl LanguageModel for OllamaClient {
1005    async fn complete_chat(
1006        &self,
1007        messages: &[Message],
1008        tools: &[ToolDescription],
1009        _stream: bool,
1010    ) -> Result<ModelCompletion> {
1011        // Convert messages to Ollama format
1012        let ollama_messages: Vec<Value> = messages
1013            .iter()
1014            .map(|m| {
1015                let role = match m.role {
1016                    Role::System => "system",
1017                    Role::User => "user",
1018                    Role::Assistant => "assistant",
1019                    Role::Tool => "tool",
1020                };
1021                json!({
1022                    "role": role,
1023                    "content": m.content.clone()
1024                })
1025            })
1026            .collect();
1027
1028        let mut body = json!({
1029            "model": self.model,
1030            "messages": ollama_messages,
1031            "stream": false
1032        });
1033
1034        if !tools.is_empty() {
1035            let ollama_tools: Vec<Value> = tools
1036                .iter()
1037                .map(|t| {
1038                    json!({
1039                        "type": "function",
1040                        "function": {
1041                            "name": t.name,
1042                            "description": t.description,
1043                            "parameters": t.parameters
1044                        }
1045                    })
1046                })
1047                .collect();
1048            body["tools"] = json!(ollama_tools);
1049        }
1050
1051        let resp = self
1052            .http
1053            .post(format!("{}/api/chat", self.base_url))
1054            .header("Content-Type", "application/json")
1055            .json(&body)
1056            .send()
1057            .await
1058            .map_err(|e| AgnoError::LanguageModel(format!("Ollama request failed: {e}")))?;
1059
1060        let status = resp.status();
1061        if !status.is_success() {
1062            let body = resp.text().await.unwrap_or_default();
1063            return Err(coalesce_error(status, &body, "Ollama"));
1064        }
1065
1066        let json: Value = resp
1067            .json()
1068            .await
1069            .map_err(|e| AgnoError::LanguageModel(format!("Ollama parse error: {e}")))?;
1070
1071        let message = &json["message"];
1072        let content = message["content"].as_str().map(String::from);
1073
1074        let mut tool_calls = Vec::new();
1075        if let Some(calls) = message["tool_calls"].as_array() {
1076            for call in calls {
1077                let func = &call["function"];
1078                let name = func["name"].as_str().unwrap_or("").to_string();
1079                let args = func["arguments"].clone();
1080                tool_calls.push(ToolCall {
1081                    id: None,
1082                    name,
1083                    arguments: args,
1084                });
1085            }
1086        }
1087
1088        Ok(ModelCompletion { content, tool_calls })
1089    }
1090}
1091
1092// ─────────────────────────────────────────────────────────────────────────────
1093// Mistral AI Client
1094// ─────────────────────────────────────────────────────────────────────────────
1095
1096/// Mistral AI client using their OpenAI-compatible API.
1097/// Default model: mistral-large-latest
1098#[derive(Clone)]
1099pub struct MistralClient {
1100    http: reqwest::Client,
1101    model: String,
1102    api_key: String,
1103    base_url: String,
1104}
1105
1106impl MistralClient {
1107    pub fn new(api_key: impl Into<String>) -> Self {
1108        Self {
1109            http: reqwest::Client::builder()
1110                .timeout(Duration::from_secs(120))
1111                .build()
1112                .expect("failed to build http client"),
1113            model: "mistral-large-latest".to_string(),
1114            api_key: api_key.into(),
1115            base_url: "https://api.mistral.ai/v1".to_string(),
1116        }
1117    }
1118
1119    pub fn with_model(mut self, model: impl Into<String>) -> Self {
1120        self.model = model.into();
1121        self
1122    }
1123
1124    pub fn from_env() -> Result<Self> {
1125        let api_key = std::env::var("MISTRAL_API_KEY")
1126            .map_err(|_| AgnoError::LanguageModel("MISTRAL_API_KEY not set".into()))?;
1127        Ok(Self::new(api_key))
1128    }
1129}
1130
1131#[async_trait]
1132impl LanguageModel for MistralClient {
1133    async fn complete_chat(
1134        &self,
1135        messages: &[Message],
1136        tools: &[ToolDescription],
1137        stream: bool,
1138    ) -> Result<ModelCompletion> {
1139        // Convert messages to Mistral format (OpenAI-compatible)
1140        let mistral_messages: Vec<Value> = messages
1141            .iter()
1142            .map(|m| {
1143                let role = match m.role {
1144                    Role::System => "system",
1145                    Role::User => "user",
1146                    Role::Assistant => "assistant",
1147                    Role::Tool => "tool",
1148                };
1149
1150                let mut msg = json!({
1151                    "role": role,
1152                    "content": m.content.clone()
1153                });
1154
1155                // Add tool_call_id for tool responses
1156                if m.role == Role::Tool {
1157                    if let Some(ref tc) = m.tool_call {
1158                        if let Some(ref id) = tc.id {
1159                            msg["tool_call_id"] = json!(id);
1160                        }
1161                    }
1162                }
1163
1164                // Add tool_calls for assistant messages
1165                if let Some(ref tc) = m.tool_call {
1166                    if m.role == Role::Assistant {
1167                        msg["tool_calls"] = json!([{
1168                            "id": tc.id.clone().unwrap_or_default(),
1169                            "type": "function",
1170                            "function": {
1171                                "name": tc.name,
1172                                "arguments": serialize_tool_arguments(&tc.arguments)
1173                            }
1174                        }]);
1175                        msg["content"] = json!(null);
1176                    }
1177                }
1178
1179                msg
1180            })
1181            .collect();
1182
1183        let mut body = json!({
1184            "model": self.model,
1185            "messages": mistral_messages,
1186            "stream": stream
1187        });
1188
1189        if !tools.is_empty() {
1190            let mistral_tools: Vec<Value> = tools
1191                .iter()
1192                .map(|t| {
1193                    json!({
1194                        "type": "function",
1195                        "function": {
1196                            "name": t.name,
1197                            "description": t.description,
1198                            "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1199                        }
1200                    })
1201                })
1202                .collect();
1203            body["tools"] = json!(mistral_tools);
1204            body["tool_choice"] = json!("auto");
1205        }
1206
1207        let resp = self
1208            .http
1209            .post(format!("{}/chat/completions", self.base_url))
1210            .header("Authorization", format!("Bearer {}", self.api_key))
1211            .header("Content-Type", "application/json")
1212            .json(&body)
1213            .send()
1214            .await
1215            .map_err(|e| AgnoError::LanguageModel(format!("Mistral request failed: {e}")))?;
1216
1217        let status = resp.status();
1218        if !status.is_success() {
1219            let body = resp.text().await.unwrap_or_default();
1220            return Err(coalesce_error(status, &body, "Mistral"));
1221        }
1222
1223        // Parse response (OpenAI-compatible format)
1224        let json: Value = resp
1225            .json()
1226            .await
1227            .map_err(|e| AgnoError::LanguageModel(format!("Mistral parse error: {e}")))?;
1228
1229        let choice = json["choices"]
1230            .as_array()
1231            .and_then(|c| c.first())
1232            .ok_or_else(|| AgnoError::LanguageModel("Mistral returned no choices".into()))?;
1233
1234        let message = &choice["message"];
1235        let content = message["content"].as_str().map(String::from);
1236
1237        let mut tool_calls = Vec::new();
1238        if let Some(calls) = message["tool_calls"].as_array() {
1239            for call in calls {
1240                let id = call["id"].as_str().map(String::from);
1241                let func = &call["function"];
1242                let name = func["name"].as_str().unwrap_or("").to_string();
1243                let args_str = func["arguments"].as_str().unwrap_or("{}");
1244                let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1245                tool_calls.push(ToolCall {
1246                    id,
1247                    name,
1248                    arguments: args,
1249                });
1250            }
1251        }
1252
1253        Ok(ModelCompletion { content, tool_calls })
1254    }
1255}
1256
1257// ─────────────────────────────────────────────────────────────────────────────
1258// Azure OpenAI Client
1259// ─────────────────────────────────────────────────────────────────────────────
1260
1261/// Azure OpenAI client for Azure-hosted models.
1262#[derive(Clone)]
1263pub struct AzureOpenAIClient {
1264    http: reqwest::Client,
1265    endpoint: String,
1266    api_key: String,
1267    deployment: String,
1268    api_version: String,
1269}
1270
1271impl AzureOpenAIClient {
1272    pub fn new(
1273        endpoint: impl Into<String>,
1274        api_key: impl Into<String>,
1275        deployment: impl Into<String>,
1276    ) -> Self {
1277        Self {
1278            http: reqwest::Client::builder()
1279                .timeout(Duration::from_secs(120))
1280                .build()
1281                .expect("failed to build http client"),
1282            endpoint: endpoint.into(),
1283            api_key: api_key.into(),
1284            deployment: deployment.into(),
1285            api_version: "2024-02-01".to_string(),
1286        }
1287    }
1288
1289    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
1290        self.api_version = version.into();
1291        self
1292    }
1293
1294    pub fn from_env() -> Result<Self> {
1295        let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT")
1296            .map_err(|_| AgnoError::LanguageModel("AZURE_OPENAI_ENDPOINT not set".into()))?;
1297        let api_key = std::env::var("AZURE_OPENAI_API_KEY")
1298            .map_err(|_| AgnoError::LanguageModel("AZURE_OPENAI_API_KEY not set".into()))?;
1299        let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT")
1300            .unwrap_or_else(|_| "gpt-4".to_string());
1301        Ok(Self::new(endpoint, api_key, deployment))
1302    }
1303}
1304
1305#[async_trait]
1306impl LanguageModel for AzureOpenAIClient {
1307    async fn complete_chat(
1308        &self,
1309        messages: &[Message],
1310        tools: &[ToolDescription],
1311        stream: bool,
1312    ) -> Result<ModelCompletion> {
1313        // Convert messages to OpenAI format
1314        let azure_messages: Vec<Value> = messages
1315            .iter()
1316            .map(|m| {
1317                let role = match m.role {
1318                    Role::System => "system",
1319                    Role::User => "user",
1320                    Role::Assistant => "assistant",
1321                    Role::Tool => "tool",
1322                };
1323
1324                let mut msg = json!({
1325                    "role": role,
1326                    "content": m.content.clone()
1327                });
1328
1329                if m.role == Role::Tool {
1330                    if let Some(ref tc) = m.tool_call {
1331                        if let Some(ref id) = tc.id {
1332                            msg["tool_call_id"] = json!(id);
1333                        }
1334                    }
1335                }
1336
1337                if let Some(ref tc) = m.tool_call {
1338                    if m.role == Role::Assistant {
1339                        msg["tool_calls"] = json!([{
1340                            "id": tc.id.clone().unwrap_or_default(),
1341                            "type": "function",
1342                            "function": {
1343                                "name": tc.name,
1344                                "arguments": serialize_tool_arguments(&tc.arguments)
1345                            }
1346                        }]);
1347                        msg["content"] = json!(null);
1348                    }
1349                }
1350
1351                msg
1352            })
1353            .collect();
1354
1355        let mut body = json!({
1356            "messages": azure_messages,
1357            "stream": stream
1358        });
1359
1360        if !tools.is_empty() {
1361            let azure_tools: Vec<Value> = tools
1362                .iter()
1363                .map(|t| {
1364                    json!({
1365                        "type": "function",
1366                        "function": {
1367                            "name": t.name,
1368                            "description": t.description,
1369                            "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1370                        }
1371                    })
1372                })
1373                .collect();
1374            body["tools"] = json!(azure_tools);
1375            body["tool_choice"] = json!("auto");
1376        }
1377
1378        let url = format!(
1379            "{}/openai/deployments/{}/chat/completions?api-version={}",
1380            self.endpoint, self.deployment, self.api_version
1381        );
1382
1383        let resp = self
1384            .http
1385            .post(&url)
1386            .header("api-key", &self.api_key)
1387            .header("Content-Type", "application/json")
1388            .json(&body)
1389            .send()
1390            .await
1391            .map_err(|e| AgnoError::LanguageModel(format!("Azure OpenAI request failed: {e}")))?;
1392
1393        let status = resp.status();
1394        if !status.is_success() {
1395            let body = resp.text().await.unwrap_or_default();
1396            return Err(coalesce_error(status, &body, "Azure OpenAI"));
1397        }
1398
1399        let json: Value = resp
1400            .json()
1401            .await
1402            .map_err(|e| AgnoError::LanguageModel(format!("Azure OpenAI parse error: {e}")))?;
1403
1404        let choice = json["choices"]
1405            .as_array()
1406            .and_then(|c| c.first())
1407            .ok_or_else(|| AgnoError::LanguageModel("Azure OpenAI returned no choices".into()))?;
1408
1409        let message = &choice["message"];
1410        let content = message["content"].as_str().map(String::from);
1411
1412        let mut tool_calls = Vec::new();
1413        if let Some(calls) = message["tool_calls"].as_array() {
1414            for call in calls {
1415                let id = call["id"].as_str().map(String::from);
1416                let func = &call["function"];
1417                let name = func["name"].as_str().unwrap_or("").to_string();
1418                let args_str = func["arguments"].as_str().unwrap_or("{}");
1419                let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1420                tool_calls.push(ToolCall {
1421                    id,
1422                    name,
1423                    arguments: args,
1424                });
1425            }
1426        }
1427
1428        Ok(ModelCompletion { content, tool_calls })
1429    }
1430}
1431
1432// ─────────────────────────────────────────────────────────────────────────────
1433// Together AI Client
1434// ─────────────────────────────────────────────────────────────────────────────
1435
1436/// Together AI client using their OpenAI-compatible API.
1437/// Default model: meta-llama/Llama-3.3-70B-Instruct-Turbo
1438#[derive(Clone)]
1439pub struct TogetherClient {
1440    http: reqwest::Client,
1441    model: String,
1442    api_key: String,
1443}
1444
1445impl TogetherClient {
1446    pub fn new(api_key: impl Into<String>) -> Self {
1447        Self {
1448            http: reqwest::Client::builder()
1449                .timeout(Duration::from_secs(120))
1450                .build()
1451                .expect("failed to build http client"),
1452            model: "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(),
1453            api_key: api_key.into(),
1454        }
1455    }
1456
1457    pub fn with_model(mut self, model: impl Into<String>) -> Self {
1458        self.model = model.into();
1459        self
1460    }
1461
1462    pub fn from_env() -> Result<Self> {
1463        let api_key = std::env::var("TOGETHER_API_KEY")
1464            .map_err(|_| AgnoError::LanguageModel("TOGETHER_API_KEY not set".into()))?;
1465        Ok(Self::new(api_key))
1466    }
1467}
1468
1469#[async_trait]
1470impl LanguageModel for TogetherClient {
1471    async fn complete_chat(
1472        &self,
1473        messages: &[Message],
1474        tools: &[ToolDescription],
1475        stream: bool,
1476    ) -> Result<ModelCompletion> {
1477        let together_messages: Vec<Value> = messages
1478            .iter()
1479            .map(|m| {
1480                let role = match m.role {
1481                    Role::System => "system",
1482                    Role::User => "user",
1483                    Role::Assistant => "assistant",
1484                    Role::Tool => "tool",
1485                };
1486                json!({
1487                    "role": role,
1488                    "content": m.content.clone()
1489                })
1490            })
1491            .collect();
1492
1493        let mut body = json!({
1494            "model": self.model,
1495            "messages": together_messages,
1496            "stream": stream
1497        });
1498
1499        if !tools.is_empty() {
1500            let together_tools: Vec<Value> = tools
1501                .iter()
1502                .map(|t| {
1503                    json!({
1504                        "type": "function",
1505                        "function": {
1506                            "name": t.name,
1507                            "description": t.description,
1508                            "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1509                        }
1510                    })
1511                })
1512                .collect();
1513            body["tools"] = json!(together_tools);
1514        }
1515
1516        let resp = self
1517            .http
1518            .post("https://api.together.xyz/v1/chat/completions")
1519            .header("Authorization", format!("Bearer {}", self.api_key))
1520            .header("Content-Type", "application/json")
1521            .json(&body)
1522            .send()
1523            .await
1524            .map_err(|e| AgnoError::LanguageModel(format!("Together request failed: {e}")))?;
1525
1526        let status = resp.status();
1527        if !status.is_success() {
1528            let body = resp.text().await.unwrap_or_default();
1529            return Err(coalesce_error(status, &body, "Together"));
1530        }
1531
1532        let json: Value = resp
1533            .json()
1534            .await
1535            .map_err(|e| AgnoError::LanguageModel(format!("Together parse error: {e}")))?;
1536
1537        let choice = json["choices"]
1538            .as_array()
1539            .and_then(|c| c.first())
1540            .ok_or_else(|| AgnoError::LanguageModel("Together returned no choices".into()))?;
1541
1542        let message = &choice["message"];
1543        let content = message["content"].as_str().map(String::from);
1544
1545        let mut tool_calls = Vec::new();
1546        if let Some(calls) = message["tool_calls"].as_array() {
1547            for call in calls {
1548                let id = call["id"].as_str().map(String::from);
1549                let func = &call["function"];
1550                let name = func["name"].as_str().unwrap_or("").to_string();
1551                let args_str = func["arguments"].as_str().unwrap_or("{}");
1552                let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1553                tool_calls.push(ToolCall {
1554                    id,
1555                    name,
1556                    arguments: args,
1557                });
1558            }
1559        }
1560
1561        Ok(ModelCompletion { content, tool_calls })
1562    }
1563}
1564
1565// ─────────────────────────────────────────────────────────────────────────────
1566// Fireworks AI Client
1567// ─────────────────────────────────────────────────────────────────────────────
1568
1569/// Fireworks AI client using their OpenAI-compatible API.
1570/// Default model: accounts/fireworks/models/llama-v3p1-70b-instruct
1571#[derive(Clone)]
1572pub struct FireworksClient {
1573    http: reqwest::Client,
1574    model: String,
1575    api_key: String,
1576}
1577
1578impl FireworksClient {
1579    pub fn new(api_key: impl Into<String>) -> Self {
1580        Self {
1581            http: reqwest::Client::builder()
1582                .timeout(Duration::from_secs(120))
1583                .build()
1584                .expect("failed to build http client"),
1585            model: "accounts/fireworks/models/llama-v3p1-70b-instruct".to_string(),
1586            api_key: api_key.into(),
1587        }
1588    }
1589
1590    pub fn with_model(mut self, model: impl Into<String>) -> Self {
1591        self.model = model.into();
1592        self
1593    }
1594
1595    pub fn from_env() -> Result<Self> {
1596        let api_key = std::env::var("FIREWORKS_API_KEY")
1597            .map_err(|_| AgnoError::LanguageModel("FIREWORKS_API_KEY not set".into()))?;
1598        Ok(Self::new(api_key))
1599    }
1600}
1601
1602#[async_trait]
1603impl LanguageModel for FireworksClient {
1604    async fn complete_chat(
1605        &self,
1606        messages: &[Message],
1607        tools: &[ToolDescription],
1608        stream: bool,
1609    ) -> Result<ModelCompletion> {
1610        let fireworks_messages: Vec<Value> = messages
1611            .iter()
1612            .map(|m| {
1613                let role = match m.role {
1614                    Role::System => "system",
1615                    Role::User => "user",
1616                    Role::Assistant => "assistant",
1617                    Role::Tool => "tool",
1618                };
1619                json!({
1620                    "role": role,
1621                    "content": m.content.clone()
1622                })
1623            })
1624            .collect();
1625
1626        let mut body = json!({
1627            "model": self.model,
1628            "messages": fireworks_messages,
1629            "stream": stream
1630        });
1631
1632        if !tools.is_empty() {
1633            let fireworks_tools: Vec<Value> = tools
1634                .iter()
1635                .map(|t| {
1636                    json!({
1637                        "type": "function",
1638                        "function": {
1639                            "name": t.name,
1640                            "description": t.description,
1641                            "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1642                        }
1643                    })
1644                })
1645                .collect();
1646            body["tools"] = json!(fireworks_tools);
1647        }
1648
1649        let resp = self
1650            .http
1651            .post("https://api.fireworks.ai/inference/v1/chat/completions")
1652            .header("Authorization", format!("Bearer {}", self.api_key))
1653            .header("Content-Type", "application/json")
1654            .json(&body)
1655            .send()
1656            .await
1657            .map_err(|e| AgnoError::LanguageModel(format!("Fireworks request failed: {e}")))?;
1658
1659        let status = resp.status();
1660        if !status.is_success() {
1661            let body = resp.text().await.unwrap_or_default();
1662            return Err(coalesce_error(status, &body, "Fireworks"));
1663        }
1664
1665        let json: Value = resp
1666            .json()
1667            .await
1668            .map_err(|e| AgnoError::LanguageModel(format!("Fireworks parse error: {e}")))?;
1669
1670        let choice = json["choices"]
1671            .as_array()
1672            .and_then(|c| c.first())
1673            .ok_or_else(|| AgnoError::LanguageModel("Fireworks returned no choices".into()))?;
1674
1675        let message = &choice["message"];
1676        let content = message["content"].as_str().map(String::from);
1677
1678        let mut tool_calls = Vec::new();
1679        if let Some(calls) = message["tool_calls"].as_array() {
1680            for call in calls {
1681                let id = call["id"].as_str().map(String::from);
1682                let func = &call["function"];
1683                let name = func["name"].as_str().unwrap_or("").to_string();
1684                let args_str = func["arguments"].as_str().unwrap_or("{}");
1685                let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1686                tool_calls.push(ToolCall {
1687                    id,
1688                    name,
1689                    arguments: args,
1690                });
1691            }
1692        }
1693
1694        Ok(ModelCompletion { content, tool_calls })
1695    }
1696}
1697
1698// ─────────────────────────────────────────────────────────────────────────────
1699// AWS Bedrock Client
1700// ─────────────────────────────────────────────────────────────────────────────
1701
1702/// AWS Bedrock client.
1703/// Currently optimized for Anthropic Claude 3 models on Bedrock.
1704#[derive(Clone)]
1705pub struct AwsBedrockClient {
1706    client: std::sync::Arc<aws_sdk_bedrockruntime::Client>,
1707    model_id: String,
1708}
1709
1710impl AwsBedrockClient {
1711    pub async fn new(region: Option<String>) -> Self {
1712        let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
1713        if let Some(r) = region {
1714             loader = loader.region(aws_config::Region::new(r));
1715        }
1716        let sdk_config = loader.load().await;
1717        let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
1718        
1719        Self {
1720            client: std::sync::Arc::new(client),
1721            model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
1722        }
1723    }
1724
1725    pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
1726        self.model_id = model_id.into();
1727        self
1728    }
1729}
1730
1731#[async_trait]
1732impl LanguageModel for AwsBedrockClient {
1733    async fn complete_chat(
1734        &self,
1735        messages: &[Message],
1736        tools: &[ToolDescription],
1737        _stream: bool, // Streaming not implemented for this pass
1738    ) -> Result<ModelCompletion> {
1739        // Construct Anthropic Messages API payload for Bedrock
1740        let system_prompt = messages
1741            .iter()
1742            .filter(|m| m.role == Role::System)
1743            .map(|m| m.content.clone())
1744            .collect::<Vec<_>>()
1745            .join("\n");
1746
1747        let mut bedrock_messages = Vec::new();
1748        for m in messages {
1749            if m.role == Role::System { continue; }
1750            
1751            let role = match m.role {
1752                Role::User => "user",
1753                Role::Assistant => "assistant",
1754                Role::Tool => "user", // Tool results return as user content in Anthropic API
1755                _ => "user",
1756            };
1757
1758            let content = if m.role == Role::Tool {
1759                // Handle tool results
1760                json!([{
1761                    "type": "tool_result",
1762                    "tool_use_id": m.tool_call.as_ref().and_then(|t| t.id.clone()).unwrap_or_default(),
1763                    "content": m.content
1764                }])
1765            } else if let Some(ref tc) = m.tool_call {
1766                 // Handle assistant tool use
1767                 json!([{
1768                    "type": "tool_use",
1769                    "id": tc.id.clone().unwrap_or_default(),
1770                    "name": tc.name,
1771                    "input": tc.arguments
1772                }])
1773            } else {
1774                json!(m.content)
1775            };
1776
1777            bedrock_messages.push(json!({
1778                "role": role,
1779                "content": content
1780            }));
1781        }
1782
1783        let mut body = json!({
1784            "anthropic_version": "bedrock-2023-05-31",
1785            "max_tokens": 4096,
1786            "messages": bedrock_messages
1787        });
1788
1789        if !system_prompt.is_empty() {
1790             body["system"] = json!(system_prompt);
1791        }
1792
1793        if !tools.is_empty() {
1794            let tool_defs: Vec<Value> = tools.iter().map(|t| {
1795                json!({
1796                    "name": t.name,
1797                    "description": t.description,
1798                    "input_schema": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1799                })
1800            }).collect();
1801            body["tools"] = json!(tool_defs);
1802        }
1803
1804        let blob = aws_sdk_bedrockruntime::primitives::Blob::new(serde_json::to_vec(&body).unwrap());
1805
1806        let output = self.client
1807            .invoke_model()
1808            .model_id(&self.model_id)
1809            .body(blob)
1810            .send()
1811            .await
1812            .map_err(|e| AgnoError::LanguageModel(format!("Bedrock invocation failed: {}", e)))?;
1813
1814        let response_body: Value = serde_json::from_slice(output.body.as_ref())
1815            .map_err(|e| AgnoError::LanguageModel(format!("Failed to parse Bedrock response: {}", e)))?;
1816
1817        let mut content = None;
1818        let mut tool_calls = Vec::new();
1819
1820        if let Some(content_blocks) = response_body["content"].as_array() {
1821            let mut text_parts = Vec::new();
1822            for block in content_blocks {
1823                if block["type"] == "text" {
1824                    if let Some(text) = block["text"].as_str() {
1825                        text_parts.push(text);
1826                    }
1827                } else if block["type"] == "tool_use" {
1828                    let id = block["id"].as_str().map(String::from);
1829                    let name = block["name"].as_str().unwrap_or_default().to_string();
1830                    let args = block["input"].clone();
1831                    tool_calls.push(ToolCall { id, name, arguments: args });
1832                }
1833            }
1834            if !text_parts.is_empty() {
1835                content = Some(text_parts.join("\n"));
1836            }
1837        }
1838
1839        Ok(ModelCompletion { content, tool_calls })
1840    }
1841}
1842
1843
/// Scripted [`LanguageModel`] for tests.
///
/// Each `complete_chat` call takes the next queued response string, in the
/// order the responses were supplied.
pub struct StubModel {
    /// Pending scripted responses, consumed from the front.
    responses: Mutex<VecDeque<String>>,
}

impl StubModel {
    /// Create a stub that replays `responses` in order.
    ///
    /// The stub is returned behind an `Arc` so it can be shared.
    pub fn new(responses: Vec<String>) -> Arc<Self> {
        let queue = VecDeque::from(responses);
        Arc::new(Self {
            responses: Mutex::new(queue),
        })
    }
}
1855}
1856
/// JSON directive that a scripted `StubModel` response may contain.
///
/// Internally tagged on `action`, with variant names in snake_case, e.g.
/// `{"action": "respond", "content": "..."}` or
/// `{"action": "call_tool", "name": "...", "arguments": {...}}`.
/// A scripted string that does not parse as a directive is used verbatim as reply text.
#[derive(Debug, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
enum StubDirective {
    /// Reply with plain text.
    Respond { content: String },
    /// Request a single tool call with the given arguments.
    CallTool { name: String, arguments: Value },
}
1863
1864#[async_trait]
1865impl LanguageModel for StubModel {
1866    async fn complete_chat(
1867        &self,
1868        _messages: &[Message],
1869        _tools: &[ToolDescription],
1870        _stream: bool,
1871    ) -> Result<ModelCompletion> {
1872        let mut locked = self.responses.lock().expect("stub model poisoned");
1873        let raw = locked.pop_front().ok_or_else(|| {
1874            AgnoError::LanguageModel("StubModel ran out of scripted responses".into())
1875        })?;
1876
1877        match serde_json::from_str::<StubDirective>(&raw) {
1878            Ok(StubDirective::Respond { content }) => Ok(ModelCompletion {
1879                content: Some(content),
1880                tool_calls: Vec::new(),
1881            }),
1882            Ok(StubDirective::CallTool { name, arguments }) => Ok(ModelCompletion {
1883                content: None,
1884                tool_calls: vec![ToolCall {
1885                    id: None,
1886                    name,
1887                    arguments,
1888                }],
1889            }),
1890            Err(_) => Ok(ModelCompletion {
1891                content: Some(raw),
1892                tool_calls: Vec::new(),
1893            }),
1894        }
1895    }
1896}
1897
/// Chat message in the OpenAI chat-completions wire format.
///
/// Fields that are `None` are omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    /// Role string, e.g. `"user"`; the exact values are chosen by callers.
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Presumably the id of the tool call a tool-result message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Tool calls attached to the message (assistant turns).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// One tool call inside an [`OpenAiMessage`] or a response choice.
///
/// `r#type` serializes as the JSON key `type`.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    r#type: String,
    function: OpenAiFunctionCall,
}

/// Function name and arguments of an [`OpenAiToolCall`].
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    /// Arguments as a JSON-encoded string, not a nested JSON object.
    arguments: String,
}

/// Tool definition sent with a request (`r#type` serializes as `type`).
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiTool {
    r#type: String,
    function: OpenAiFunction,
}

/// Function schema inside an [`OpenAiTool`]; `None` fields are omitted.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunction {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON Schema of the function parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Non-streaming chat-completions response body.
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
}

/// One choice in an [`OpenAiResponse`].
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiChoiceMessage,
    // Redundant with the file-level `#![allow(dead_code)]`; presumably marks
    // the value as deserialized but intentionally unread.
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Assistant message carried by an [`OpenAiChoice`].
#[derive(Debug, Deserialize)]
struct OpenAiChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// Mutable state for one tool call.
///
/// Presumably used to accumulate a tool call from streaming deltas; the
/// `arguments` string is where fragments would be appended. Verify against
/// the streaming code.
#[derive(Default)]
struct OpenAiToolCallState {
    id: Option<String>,
    name: Option<String>,
    arguments: String,
}

/// One server-sent chunk of a streaming chat-completions response.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChunk {
    choices: Vec<OpenAiDeltaChoice>,
}

/// Incremental choice inside an [`OpenAiStreamChunk`].
#[derive(Debug, Deserialize)]
struct OpenAiDeltaChoice {
    delta: OpenAiDelta,
    // Redundant with the file-level `#![allow(dead_code)]`; presumably unread.
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Partial message content carried by an [`OpenAiDeltaChoice`].
#[derive(Debug, Deserialize)]
struct OpenAiDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

/// Fragment of a tool call in a streaming delta.
///
/// NOTE(review): OpenAI stream deltas also carry an `index` that identifies
/// which tool call a fragment belongs to. It is not captured here, so parallel
/// tool calls may be hard to separate. Confirm how callers correlate fragments.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

/// Partial function name/arguments inside an [`OpenAiToolCallDelta`].
#[derive(Debug, Deserialize)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
1997
/// Message in the Anthropic Messages API wire format.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: Vec<AnthropicContentBlock>,
}

/// Content block used in both requests and responses.
///
/// `r#type` serializes as the JSON key `type`, and `None` fields are omitted.
/// NOTE(review): there are no `id`/`input` fields, so a `tool_use` block in a
/// response deserializes without its call id or arguments. Confirm whether the
/// Anthropic client relies on this type for tool calls.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContentBlock {
    r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    input_schema: Option<Value>,
}

/// Tool definition sent with an Anthropic request.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicTool {
    name: String,
    description: String,
    /// JSON Schema describing the tool input.
    input_schema: Value,
}

/// Non-streaming Anthropic response body; only `content` is captured.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContentBlock>,
}

/// Streaming event that carries a `delta`.
///
/// `delta` is required, so events without one fail to deserialize into this
/// type. Presumably callers skip those; confirm.
#[derive(Debug, Deserialize)]
struct AnthropicStreamChunk {
    delta: AnthropicDelta,
}

/// Incremental payload of an [`AnthropicStreamChunk`]; only text is captured.
#[derive(Debug, Deserialize)]
struct AnthropicDelta {
    #[serde(default)]
    text: Option<String>,
}
2037
/// Message in the Gemini `contents` wire format.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiMessage {
    role: String,
    parts: Vec<GeminiPart>,
}

/// Text-only content part.
///
/// `text` is required, so a part without it makes deserialization fail.
/// NOTE(review): confirm whether non-text parts can appear in responses.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
    text: String,
}

/// Gemini response body; only `candidates` is captured.
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
}

/// One candidate in a [`GeminiResponse`].
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
    content: GeminiCandidateContent,
}

/// Content of a [`GeminiCandidate`], as a list of text parts.
#[derive(Debug, Deserialize)]
struct GeminiCandidateContent {
    parts: Vec<GeminiPart>,
}
2063
// Cohere data structures
/// Chat message sent to Cohere.
#[derive(Debug, Serialize, Deserialize)]
struct CohereMessage {
    role: String,
    content: String,
}

/// Tool definition sent to Cohere (`r#type` serializes as the JSON key `type`).
#[derive(Debug, Serialize, Deserialize)]
struct CohereTool {
    r#type: String,
    function: CohereFunction,
}

/// Function schema inside a [`CohereTool`]; `None` fields are omitted.
#[derive(Debug, Serialize, Deserialize)]
struct CohereFunction {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON Schema of the function parameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Cohere chat response body; both fields are optional.
#[derive(Debug, Deserialize)]
struct CohereResponse {
    #[serde(default)]
    message: Option<CohereResponseMessage>,
    #[serde(default)]
    tool_calls: Option<Vec<CohereToolCall>>,
}

/// Message payload of a response or stream delta.
///
/// `content` is left as an untyped `Value`; its exact shape is not
/// established here, so callers presumably inspect it dynamically.
#[derive(Debug, Deserialize)]
struct CohereResponseMessage {
    #[serde(default)]
    content: Option<Value>,
}

/// Tool call returned by Cohere.
#[derive(Debug, Deserialize)]
struct CohereToolCall {
    #[serde(default)]
    id: Option<String>,
    function: CohereFunctionCall,
}

/// Function name and string-encoded arguments of a [`CohereToolCall`].
#[derive(Debug, Deserialize)]
struct CohereFunctionCall {
    name: String,
    arguments: String,
}

/// One streaming event; only the optional `delta` is captured.
#[derive(Debug, Deserialize)]
struct CohereStreamChunk {
    #[serde(default)]
    delta: Option<CohereDelta>,
}

/// Incremental payload of a [`CohereStreamChunk`].
#[derive(Debug, Deserialize)]
struct CohereDelta {
    #[serde(default)]
    message: Option<CohereResponseMessage>,
}
2124