1#![allow(dead_code)]
3
4use std::collections::{HashMap, VecDeque};
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12
13use crate::config::ModelConfig;
14use crate::error::{AgnoError, Result};
15use crate::message::{Message, Role, ToolCall};
16use crate::tool::ToolDescription;
17
/// The normalized result of a chat completion from any provider.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ModelCompletion {
    /// Assistant text, or `None` when the model produced no text
    /// (e.g. a pure tool-call turn or an empty stream).
    pub content: Option<String>,
    /// Tool invocations requested by the model; empty when none.
    pub tool_calls: Vec<ToolCall>,
}
24
#[async_trait]
pub trait LanguageModel: Send + Sync {
    /// Request a chat completion for `messages`, optionally exposing `tools`.
    ///
    /// When `stream` is true the provider's streaming API is used and the
    /// streamed deltas are accumulated into a single `ModelCompletion`
    /// before returning; implementations that do not support streaming
    /// may ignore the flag.
    async fn complete_chat(
        &self,
        messages: &[Message],
        tools: &[ToolDescription],
        stream: bool,
    ) -> Result<ModelCompletion>;
}
35
36fn coalesce_error(status: reqwest::StatusCode, body: &str, provider: &str) -> AgnoError {
37 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
38 return AgnoError::LanguageModel(format!("{provider} rate limit exceeded: {body}"));
39 }
40 AgnoError::LanguageModel(format!("{provider} request failed with {}: {body}", status))
41}
42
43fn serialize_tool_arguments(args: &Value) -> String {
44 serde_json::to_string(args).unwrap_or_else(|_| args.to_string())
45}
46
/// Client for the OpenAI chat-completions API (and compatible endpoints).
#[derive(Clone)]
pub struct OpenAIClient {
    http: reqwest::Client,
    // Model identifier sent in the request payload, e.g. "gpt-4-turbo-preview".
    model: String,
    api_key: String,
    // API root without a trailing slash; "/chat/completions" is appended.
    base_url: String,
    // Optional value for the "OpenAI-Organization" header.
    organization: Option<String>,
}
55
56impl OpenAIClient {
57 pub fn new(api_key: impl Into<String>) -> Self {
58 Self {
59 http: reqwest::Client::new(),
60 model: "gpt-4-turbo-preview".to_string(),
61 api_key: api_key.into(),
62 base_url: "https://api.openai.com/v1".to_string(),
63 organization: None,
64 }
65 }
66
67 pub fn from_env() -> Result<Self> {
68 let api_key = std::env::var("OPENAI_API_KEY")
69 .map_err(|_| AgnoError::LanguageModel("OPENAI_API_KEY not found".into()))?;
70 Ok(Self::new(api_key))
71 }
72
73 pub fn with_model(mut self, model: impl Into<String>) -> Self {
74 self.model = model.into();
75 self
76 }
77
78 pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
79 let api_key = cfg
80 .openai
81 .api_key
82 .clone()
83 .or_else(|| cfg.api_key.clone())
84 .ok_or_else(|| {
85 AgnoError::LanguageModel("missing OpenAI API key in model config".into())
86 })?;
87 let base_url = cfg
88 .openai
89 .endpoint
90 .clone()
91 .or_else(|| cfg.base_url.clone())
92 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
93 Ok(Self {
94 http: reqwest::Client::builder()
95 .timeout(Duration::from_secs(60))
96 .build()
97 .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
98 model: cfg.model.clone(),
99 api_key,
100 base_url,
101 organization: cfg
102 .openai
103 .organization
104 .clone()
105 .or_else(|| cfg.organization.clone()),
106 })
107 }
108
109 fn to_openai_messages(&self, messages: &[Message]) -> Vec<OpenAiMessage> {
110 let mut built = Vec::new();
111 for message in messages {
112 let role = match message.role {
113 Role::System => "system",
114 Role::User => "user",
115 Role::Assistant => "assistant",
116 Role::Tool => "tool",
117 }
118 .to_string();
119
120 let mut tool_calls = None;
121 if let Some(call) = &message.tool_call {
122 tool_calls = Some(vec![OpenAiToolCall {
123 id: call.id.clone(),
124 r#type: "function".to_string(),
125 function: OpenAiFunctionCall {
126 name: call.name.clone(),
127 arguments: serialize_tool_arguments(&call.arguments),
128 },
129 }]);
130 }
131
132 let content = if message.role == Role::Tool {
133 message
134 .tool_result
135 .as_ref()
136 .map(|result| serialize_tool_arguments(&result.output))
137 .or_else(|| Some(message.content.clone()))
138 } else {
139 Some(message.content.clone())
140 };
141
142 let tool_call_id = message
143 .tool_result
144 .as_ref()
145 .and_then(|result| result.tool_call_id.clone());
146
147 built.push(OpenAiMessage {
148 role,
149 content,
150 tool_call_id,
151 tool_calls,
152 });
153 }
154 built
155 }
156
157 fn to_openai_tools(&self, tools: &[ToolDescription]) -> Option<Vec<OpenAiTool>> {
158 if tools.is_empty() {
159 return None;
160 }
161
162 Some(
163 tools
164 .iter()
165 .map(|tool| OpenAiTool {
166 r#type: "function".to_string(),
167 function: OpenAiFunction {
168 name: tool.name.clone(),
169 description: Some(tool.description.clone()),
170 parameters: tool.parameters.clone(),
171 },
172 })
173 .collect(),
174 )
175 }
176}
177
178#[async_trait]
179impl LanguageModel for OpenAIClient {
180 async fn complete_chat(
181 &self,
182 messages: &[Message],
183 tools: &[ToolDescription],
184 stream: bool,
185 ) -> Result<ModelCompletion> {
186 let payload = json!({
187 "model": self.model,
188 "messages": self.to_openai_messages(messages),
189 "tools": self.to_openai_tools(tools),
190 "tool_choice": if tools.is_empty() { Value::Null } else { Value::String("auto".to_string()) },
191 "stream": stream,
192 });
193
194 let mut builder = self
195 .http
196 .post(format!("{}/chat/completions", self.base_url))
197 .header(
198 reqwest::header::AUTHORIZATION,
199 format!("Bearer {}", self.api_key),
200 );
201 if let Some(org) = &self.organization {
202 builder = builder.header("OpenAI-Organization", org);
203 }
204 let resp = builder
205 .json(&payload)
206 .send()
207 .await
208 .map_err(|err| AgnoError::LanguageModel(format!("OpenAI request error: {err}")))?;
209
210 if !resp.status().is_success() {
211 let status = resp.status();
212 let body = resp.text().await.unwrap_or_default();
213 return Err(coalesce_error(status, &body, "openai"));
214 }
215
216 if stream {
217 let mut content = String::new();
218 let mut tool_calls: HashMap<String, OpenAiToolCallState> = HashMap::new();
219 let mut stream = resp.bytes_stream();
220 while let Some(chunk) = stream.next().await {
221 let chunk = chunk.map_err(|err| {
222 AgnoError::LanguageModel(format!("OpenAI stream error: {err}"))
223 })?;
224 let text = String::from_utf8_lossy(&chunk);
225 for line in text.lines() {
226 if !line.starts_with("data: ") {
227 continue;
228 }
229 let data = line.trim_start_matches("data: ").trim();
230 if data == "[DONE]" {
231 continue;
232 }
233 let parsed: OpenAiStreamChunk = serde_json::from_str(data).map_err(|err| {
234 AgnoError::LanguageModel(format!(
235 "OpenAI stream parse error `{data}`: {err}"
236 ))
237 })?;
238
239 for choice in parsed.choices {
240 if let Some(delta_content) = choice.delta.content {
241 content.push_str(&delta_content);
242 }
243 if let Some(calls) = choice.delta.tool_calls {
244 for delta_call in calls {
245 let id = delta_call
246 .id
247 .clone()
248 .unwrap_or_else(|| format!("call_{}", tool_calls.len()));
249 let state = tool_calls.entry(id.clone()).or_default();
250 if let Some(function) = delta_call.function {
251 if let Some(name) = function.name {
252 state.name = Some(name);
253 }
254 if let Some(args) = function.arguments {
255 state.arguments.push_str(&args);
256 }
257 }
258 state.id = Some(id);
259 }
260 }
261 }
262 }
263 }
264
265 let calls: Vec<ToolCall> = tool_calls
266 .into_values()
267 .filter_map(|state| {
268 let name = state.name?;
269 let args = serde_json::from_str(&state.arguments)
270 .unwrap_or_else(|_| Value::String(state.arguments.clone()));
271 Some(ToolCall {
272 id: state.id,
273 name,
274 arguments: args,
275 })
276 })
277 .collect();
278
279 return Ok(ModelCompletion {
280 content: if content.is_empty() {
281 None
282 } else {
283 Some(content)
284 },
285 tool_calls: calls,
286 });
287 }
288
289 let body: OpenAiResponse = resp.json().await.map_err(|err| {
290 AgnoError::LanguageModel(format!("OpenAI response parse error: {err}"))
291 })?;
292
293 let first = body
294 .choices
295 .into_iter()
296 .next()
297 .ok_or_else(|| AgnoError::LanguageModel("OpenAI returned no choices".into()))?;
298
299 let mut tool_calls = Vec::new();
300 if let Some(calls) = first.message.tool_calls {
301 for call in calls {
302 let args = serde_json::from_str(&call.function.arguments)
303 .unwrap_or_else(|_| Value::String(call.function.arguments.clone()));
304 tool_calls.push(ToolCall {
305 id: call.id,
306 name: call.function.name,
307 arguments: args,
308 });
309 }
310 }
311
312 Ok(ModelCompletion {
313 content: first.message.content,
314 tool_calls,
315 })
316 }
317}
318
/// Client for the Anthropic Messages API.
#[derive(Clone)]
pub struct AnthropicClient {
    http: reqwest::Client,
    // Model identifier sent in the request payload.
    model: String,
    api_key: String,
    // Full messages endpoint URL (not a base URL).
    endpoint: String,
}
326
327impl AnthropicClient {
328 pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
329 let api_key = cfg
330 .anthropic
331 .api_key
332 .clone()
333 .or_else(|| cfg.api_key.clone())
334 .ok_or_else(|| {
335 AgnoError::LanguageModel("missing Anthropic API key in model config".into())
336 })?;
337 let endpoint = cfg
338 .anthropic
339 .endpoint
340 .clone()
341 .unwrap_or_else(|| "https://api.anthropic.com/v1/messages".to_string());
342 Ok(Self {
343 http: reqwest::Client::builder()
344 .timeout(Duration::from_secs(60))
345 .build()
346 .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
347 model: cfg.model.clone(),
348 api_key,
349 endpoint,
350 })
351 }
352
353 fn to_messages(&self, messages: &[Message]) -> Vec<AnthropicMessage> {
354 messages
355 .iter()
356 .filter_map(|message| match message.role {
357 Role::System => None,
358 Role::User | Role::Assistant | Role::Tool => Some(AnthropicMessage {
359 role: match message.role {
360 Role::User => "user",
361 Role::Assistant | Role::Tool => "assistant",
362 Role::System => unreachable!(),
363 }
364 .to_string(),
365 content: vec![AnthropicContentBlock {
366 r#type: "text".to_string(),
367 text: Some(message.content.clone()),
368 name: None,
369 input_schema: None,
370 }],
371 }),
372 })
373 .collect()
374 }
375
376 fn to_tools(&self, tools: &[ToolDescription]) -> Option<Vec<AnthropicTool>> {
377 if tools.is_empty() {
378 return None;
379 }
380 Some(
381 tools
382 .iter()
383 .map(|tool| AnthropicTool {
384 name: tool.name.clone(),
385 description: tool.description.clone(),
386 input_schema: tool
387 .parameters
388 .clone()
389 .unwrap_or_else(|| json!({"type":"object"})),
390 })
391 .collect(),
392 )
393 }
394}
395
396#[async_trait]
397impl LanguageModel for AnthropicClient {
398 async fn complete_chat(
399 &self,
400 messages: &[Message],
401 tools: &[ToolDescription],
402 stream: bool,
403 ) -> Result<ModelCompletion> {
404 let system = messages
405 .iter()
406 .find(|m| m.role == Role::System)
407 .map(|m| m.content.clone());
408 let payload = json!({
409 "model": self.model,
410 "system": system,
411 "messages": self.to_messages(messages),
412 "tools": self.to_tools(tools),
413 "stream": stream,
414 });
415
416 let resp = self
417 .http
418 .post(&self.endpoint)
419 .header("x-api-key", &self.api_key)
420 .header("anthropic-version", "2023-06-01")
421 .json(&payload)
422 .send()
423 .await
424 .map_err(|err| AgnoError::LanguageModel(format!("Anthropic request error: {err}")))?;
425
426 if !resp.status().is_success() {
427 let status = resp.status();
428 let body = resp.text().await.unwrap_or_default();
429 return Err(coalesce_error(status, &body, "anthropic"));
430 }
431
432 if stream {
433 let mut content = String::new();
434 let mut stream = resp.bytes_stream();
435 while let Some(chunk) = stream.next().await {
436 let chunk = chunk.map_err(|err| {
437 AgnoError::LanguageModel(format!("Anthropic stream error: {err}"))
438 })?;
439 let text = String::from_utf8_lossy(&chunk);
440 for line in text.lines() {
441 if !line.starts_with("data: ") {
442 continue;
443 }
444 let data = line.trim_start_matches("data: ").trim();
445 if data == "[DONE]" || data.is_empty() {
446 continue;
447 }
448 let parsed: AnthropicStreamChunk =
449 serde_json::from_str(data).map_err(|err| {
450 AgnoError::LanguageModel(format!(
451 "Anthropic stream parse error `{data}`: {err}"
452 ))
453 })?;
454 if let Some(text) = parsed.delta.text {
455 content.push_str(&text);
456 }
457 }
458 }
459
460 return Ok(ModelCompletion {
461 content: if content.is_empty() {
462 None
463 } else {
464 Some(content)
465 },
466 tool_calls: Vec::new(),
467 });
468 }
469
470 let parsed: AnthropicResponse = resp.json().await.map_err(|err| {
471 AgnoError::LanguageModel(format!("Anthropic response parse error: {err}"))
472 })?;
473
474 let content = parsed
475 .content
476 .iter()
477 .filter_map(|block| block.text.clone())
478 .collect::<Vec<String>>()
479 .join("");
480
481 Ok(ModelCompletion {
482 content: if content.is_empty() {
483 None
484 } else {
485 Some(content)
486 },
487 tool_calls: Vec::new(),
488 })
489 }
490}
491
/// Client for the Google Gemini `generateContent` API.
#[derive(Clone)]
pub struct GeminiClient {
    http: reqwest::Client,
    // Model identifier inserted into the request path.
    model: String,
    // Sent as the `key` query parameter (Gemini does not use a header).
    api_key: String,
    // API root, e.g. "https://generativelanguage.googleapis.com/v1beta".
    endpoint: String,
}
499
500impl GeminiClient {
501 pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
502 let api_key = cfg
503 .gemini
504 .api_key
505 .clone()
506 .or_else(|| cfg.api_key.clone())
507 .ok_or_else(|| {
508 AgnoError::LanguageModel("missing Gemini API key in model config".into())
509 })?;
510 let endpoint = cfg
511 .gemini
512 .endpoint
513 .clone()
514 .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string());
515 Ok(Self {
516 http: reqwest::Client::builder()
517 .timeout(Duration::from_secs(60))
518 .build()
519 .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
520 model: cfg.model.clone(),
521 api_key,
522 endpoint,
523 })
524 }
525
526 fn to_contents(&self, messages: &[Message]) -> Vec<GeminiMessage> {
527 messages
528 .iter()
529 .filter_map(|message| {
530 let role = match message.role {
531 Role::User => "user",
532 Role::Assistant => "model",
533 Role::System => "system",
534 Role::Tool => "user",
535 };
536 Some(GeminiMessage {
537 role: role.to_string(),
538 parts: vec![GeminiPart {
539 text: message.content.clone(),
540 }],
541 })
542 })
543 .collect()
544 }
545}
546
547#[async_trait]
548impl LanguageModel for GeminiClient {
549 async fn complete_chat(
550 &self,
551 messages: &[Message],
552 _tools: &[ToolDescription],
553 _stream: bool,
554 ) -> Result<ModelCompletion> {
555 let payload = json!({
556 "contents": self.to_contents(messages),
557 });
558 let resp = self
559 .http
560 .post(format!(
561 "{}/models/{}:generateContent?key={}",
562 self.endpoint, self.model, self.api_key
563 ))
564 .json(&payload)
565 .send()
566 .await
567 .map_err(|err| AgnoError::LanguageModel(format!("Gemini request error: {err}")))?;
568
569 if !resp.status().is_success() {
570 let status = resp.status();
571 let body = resp.text().await.unwrap_or_default();
572 return Err(coalesce_error(status, &body, "gemini"));
573 }
574
575 let parsed: GeminiResponse = resp.json().await.map_err(|err| {
576 AgnoError::LanguageModel(format!("Gemini response parse error: {err}"))
577 })?;
578
579 let content = parsed
580 .candidates
581 .get(0)
582 .and_then(|cand| cand.content.parts.get(0))
583 .map(|part| part.text.clone())
584 .unwrap_or_default();
585
586 Ok(ModelCompletion {
587 content: if content.is_empty() {
588 None
589 } else {
590 Some(content)
591 },
592 tool_calls: Vec::new(),
593 })
594 }
595}
596
/// Client for the Cohere v2 chat API.
#[derive(Clone)]
pub struct CohereClient {
    http: reqwest::Client,
    // Model identifier, e.g. "command-a-03-2025".
    model: String,
    api_key: String,
    // Full chat endpoint URL (not a base URL).
    endpoint: String,
}
604
605impl CohereClient {
606 pub fn new(api_key: impl Into<String>) -> Self {
607 Self {
608 http: reqwest::Client::builder()
609 .timeout(Duration::from_secs(60))
610 .build()
611 .expect("failed to build http client"),
612 model: "command-a-03-2025".to_string(),
613 api_key: api_key.into(),
614 endpoint: "https://api.cohere.ai/v2/chat".to_string(),
615 }
616 }
617
618 pub fn with_model(mut self, model: impl Into<String>) -> Self {
619 self.model = model.into();
620 self
621 }
622
623 pub fn from_config(cfg: &ModelConfig) -> Result<Self> {
624 let api_key = cfg
625 .cohere
626 .api_key
627 .clone()
628 .or_else(|| cfg.api_key.clone())
629 .ok_or_else(|| {
630 AgnoError::LanguageModel("missing Cohere API key in model config".into())
631 })?;
632 let endpoint = cfg
633 .cohere
634 .endpoint
635 .clone()
636 .unwrap_or_else(|| "https://api.cohere.ai/v2/chat".to_string());
637 Ok(Self {
638 http: reqwest::Client::builder()
639 .timeout(Duration::from_secs(60))
640 .build()
641 .map_err(|err| AgnoError::LanguageModel(format!("http client error: {err}")))?,
642 model: cfg.model.clone(),
643 api_key,
644 endpoint,
645 })
646 }
647
648 fn to_messages(&self, messages: &[Message]) -> Vec<CohereMessage> {
649 messages
650 .iter()
651 .map(|message| {
652 let role = match message.role {
653 Role::System => "system",
654 Role::User => "user",
655 Role::Assistant => "assistant",
656 Role::Tool => "tool",
657 };
658 CohereMessage {
659 role: role.to_string(),
660 content: message.content.clone(),
661 }
662 })
663 .collect()
664 }
665
666 fn to_tools(&self, tools: &[ToolDescription]) -> Option<Vec<CohereTool>> {
667 if tools.is_empty() {
668 return None;
669 }
670 Some(
671 tools
672 .iter()
673 .map(|tool| CohereTool {
674 r#type: "function".to_string(),
675 function: CohereFunction {
676 name: tool.name.clone(),
677 description: Some(tool.description.clone()),
678 parameters: tool.parameters.clone(),
679 },
680 })
681 .collect(),
682 )
683 }
684}
685
686#[async_trait]
687impl LanguageModel for CohereClient {
688 async fn complete_chat(
689 &self,
690 messages: &[Message],
691 tools: &[ToolDescription],
692 stream: bool,
693 ) -> Result<ModelCompletion> {
694 let payload = json!({
695 "model": self.model,
696 "messages": self.to_messages(messages),
697 "tools": self.to_tools(tools),
698 "stream": stream,
699 });
700
701 let resp = self
702 .http
703 .post(&self.endpoint)
704 .header("Authorization", format!("Bearer {}", self.api_key))
705 .header("Content-Type", "application/json")
706 .json(&payload)
707 .send()
708 .await
709 .map_err(|err| AgnoError::LanguageModel(format!("Cohere request error: {err}")))?;
710
711 if !resp.status().is_success() {
712 let status = resp.status();
713 let body = resp.text().await.unwrap_or_default();
714 return Err(coalesce_error(status, &body, "cohere"));
715 }
716
717 if stream {
718 let mut content = String::new();
719 let tool_calls_map: HashMap<String, OpenAiToolCallState> = HashMap::new();
720 let mut stream = resp.bytes_stream();
721 while let Some(chunk) = stream.next().await {
722 let chunk = chunk.map_err(|err| {
723 AgnoError::LanguageModel(format!("Cohere stream error: {err}"))
724 })?;
725 let text = String::from_utf8_lossy(&chunk);
726 for line in text.lines() {
727 if !line.starts_with("data: ") {
728 continue;
729 }
730 let data = line.trim_start_matches("data: ").trim();
731 if data == "[DONE]" || data.is_empty() {
732 continue;
733 }
734 if let Ok(parsed) = serde_json::from_str::<CohereStreamChunk>(data) {
735 if let Some(delta) = parsed.delta {
736 if let Some(msg) = delta.message {
737 if let Some(c) = msg.content {
738 if let Some(text_content) = c.get("text") {
739 if let Some(t) = text_content.as_str() {
740 content.push_str(t);
741 }
742 }
743 }
744 }
745 }
746 }
747 }
748 }
749
750 let calls: Vec<ToolCall> = tool_calls_map
751 .into_values()
752 .filter_map(|state| {
753 let name = state.name?;
754 let args = serde_json::from_str(&state.arguments)
755 .unwrap_or_else(|_| Value::String(state.arguments.clone()));
756 Some(ToolCall {
757 id: state.id,
758 name,
759 arguments: args,
760 })
761 })
762 .collect();
763
764 return Ok(ModelCompletion {
765 content: if content.is_empty() {
766 None
767 } else {
768 Some(content)
769 },
770 tool_calls: calls,
771 });
772 }
773
774 let body: CohereResponse = resp.json().await.map_err(|err| {
775 AgnoError::LanguageModel(format!("Cohere response parse error: {err}"))
776 })?;
777
778 let content = body.message.and_then(|m| {
779 m.content.and_then(|c| {
780 if let Some(arr) = c.as_array() {
781 let mut text = String::new();
782 for item in arr {
783 if let Some(t) = item.get("text").and_then(|v| v.as_str()) {
785 text.push_str(t);
786 }
787 }
788 if text.is_empty() { None } else { Some(text) }
789 } else {
790 c.get("text").and_then(|v| v.as_str().map(|s| s.to_string()))
791 }
792 })
793 });
794
795 let mut tool_calls = Vec::new();
796 if let Some(calls) = body.tool_calls {
797 for call in calls {
798 let args = serde_json::from_str(&call.function.arguments)
799 .unwrap_or_else(|_| Value::String(call.function.arguments.clone()));
800 tool_calls.push(ToolCall {
801 id: call.id,
802 name: call.function.name,
803 arguments: args,
804 });
805 }
806 }
807
808 Ok(ModelCompletion {
809 content,
810 tool_calls,
811 })
812 }
813}
814
/// Client for Groq's OpenAI-compatible chat-completions API.
#[derive(Clone)]
pub struct GroqClient {
    http: reqwest::Client,
    // Model identifier, e.g. "llama-3.3-70b-versatile".
    model: String,
    api_key: String,
    // API root without a trailing slash; "/chat/completions" is appended.
    base_url: String,
}
828
829impl GroqClient {
830 pub fn new(api_key: impl Into<String>) -> Self {
831 Self {
832 http: reqwest::Client::builder()
833 .timeout(Duration::from_secs(120))
834 .build()
835 .expect("failed to build http client"),
836 model: "llama-3.3-70b-versatile".to_string(),
837 api_key: api_key.into(),
838 base_url: "https://api.groq.com/openai/v1".to_string(),
839 }
840 }
841
842 pub fn with_model(mut self, model: impl Into<String>) -> Self {
843 self.model = model.into();
844 self
845 }
846
847 pub fn from_env() -> Result<Self> {
848 let api_key = std::env::var("GROQ_API_KEY")
849 .map_err(|_| AgnoError::LanguageModel("GROQ_API_KEY not set".into()))?;
850 Ok(Self::new(api_key))
851 }
852}
853
854#[async_trait]
855impl LanguageModel for GroqClient {
856 async fn complete_chat(
857 &self,
858 messages: &[Message],
859 tools: &[ToolDescription],
860 stream: bool,
861 ) -> Result<ModelCompletion> {
862 let oai_messages: Vec<Value> = messages
864 .iter()
865 .map(|m| {
866 let role = match m.role {
867 Role::System => "system",
868 Role::User => "user",
869 Role::Assistant => "assistant",
870 Role::Tool => "tool",
871 };
872 let mut msg = json!({
873 "role": role,
874 "content": m.content.clone()
875 });
876 if let Some(ref result) = m.tool_result {
877 if let Some(ref call_id) = result.tool_call_id {
878 msg["tool_call_id"] = json!(call_id);
879 }
880 }
881 msg
882 })
883 .collect();
884
885 let mut body = json!({
886 "model": self.model,
887 "messages": oai_messages,
888 "stream": stream
889 });
890
891 if !tools.is_empty() {
892 let oai_tools: Vec<Value> = tools
893 .iter()
894 .map(|t| {
895 json!({
896 "type": "function",
897 "function": {
898 "name": t.name,
899 "description": t.description,
900 "parameters": t.parameters
901 }
902 })
903 })
904 .collect();
905 body["tools"] = json!(oai_tools);
906 }
907
908 let resp = self
909 .http
910 .post(format!("{}/chat/completions", self.base_url))
911 .header("Authorization", format!("Bearer {}", self.api_key))
912 .header("Content-Type", "application/json")
913 .json(&body)
914 .send()
915 .await
916 .map_err(|e| AgnoError::LanguageModel(format!("Groq request failed: {e}")))?;
917
918 let status = resp.status();
919 if !status.is_success() {
920 let body = resp.text().await.unwrap_or_default();
921 return Err(coalesce_error(status, &body, "Groq"));
922 }
923
924 let json: Value = resp
925 .json()
926 .await
927 .map_err(|e| AgnoError::LanguageModel(format!("Groq parse error: {e}")))?;
928
929 let choice = &json["choices"][0]["message"];
930 let content = choice["content"].as_str().map(String::from);
931
932 let mut tool_calls = Vec::new();
933 if let Some(calls) = choice["tool_calls"].as_array() {
934 for call in calls {
935 let name = call["function"]["name"].as_str().unwrap_or("").to_string();
936 let args_str = call["function"]["arguments"].as_str().unwrap_or("{}");
937 let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
938 tool_calls.push(ToolCall {
939 id: call["id"].as_str().map(String::from),
940 name,
941 arguments: args,
942 });
943 }
944 }
945
946 Ok(ModelCompletion { content, tool_calls })
947 }
948}
949
/// Client for a local Ollama server (no API key required).
#[derive(Clone)]
pub struct OllamaClient {
    http: reqwest::Client,
    // Local model name, e.g. "llama3.1".
    model: String,
    // Server root, e.g. "http://localhost:11434"; "/api/chat" is appended.
    base_url: String,
}
962
963impl OllamaClient {
964 pub fn new() -> Self {
965 Self {
966 http: reqwest::Client::builder()
967 .timeout(Duration::from_secs(300)) .build()
969 .expect("failed to build http client"),
970 model: "llama3.1".to_string(),
971 base_url: "http://localhost:11434".to_string(),
972 }
973 }
974
975 pub fn with_model(mut self, model: impl Into<String>) -> Self {
976 self.model = model.into();
977 self
978 }
979
980 pub fn with_host(mut self, host: impl Into<String>) -> Self {
981 self.base_url = host.into();
982 self
983 }
984
985 pub fn from_env() -> Self {
986 let mut client = Self::new();
987 if let Ok(host) = std::env::var("OLLAMA_HOST") {
988 client.base_url = host;
989 }
990 if let Ok(model) = std::env::var("OLLAMA_MODEL") {
991 client.model = model;
992 }
993 client
994 }
995}
996
// `Default` delegates to `new`, so `OllamaClient::default()` targets the
// local server with the default model.
impl Default for OllamaClient {
    fn default() -> Self {
        Self::new()
    }
}
1002
1003#[async_trait]
1004impl LanguageModel for OllamaClient {
1005 async fn complete_chat(
1006 &self,
1007 messages: &[Message],
1008 tools: &[ToolDescription],
1009 _stream: bool,
1010 ) -> Result<ModelCompletion> {
1011 let ollama_messages: Vec<Value> = messages
1013 .iter()
1014 .map(|m| {
1015 let role = match m.role {
1016 Role::System => "system",
1017 Role::User => "user",
1018 Role::Assistant => "assistant",
1019 Role::Tool => "tool",
1020 };
1021 json!({
1022 "role": role,
1023 "content": m.content.clone()
1024 })
1025 })
1026 .collect();
1027
1028 let mut body = json!({
1029 "model": self.model,
1030 "messages": ollama_messages,
1031 "stream": false
1032 });
1033
1034 if !tools.is_empty() {
1035 let ollama_tools: Vec<Value> = tools
1036 .iter()
1037 .map(|t| {
1038 json!({
1039 "type": "function",
1040 "function": {
1041 "name": t.name,
1042 "description": t.description,
1043 "parameters": t.parameters
1044 }
1045 })
1046 })
1047 .collect();
1048 body["tools"] = json!(ollama_tools);
1049 }
1050
1051 let resp = self
1052 .http
1053 .post(format!("{}/api/chat", self.base_url))
1054 .header("Content-Type", "application/json")
1055 .json(&body)
1056 .send()
1057 .await
1058 .map_err(|e| AgnoError::LanguageModel(format!("Ollama request failed: {e}")))?;
1059
1060 let status = resp.status();
1061 if !status.is_success() {
1062 let body = resp.text().await.unwrap_or_default();
1063 return Err(coalesce_error(status, &body, "Ollama"));
1064 }
1065
1066 let json: Value = resp
1067 .json()
1068 .await
1069 .map_err(|e| AgnoError::LanguageModel(format!("Ollama parse error: {e}")))?;
1070
1071 let message = &json["message"];
1072 let content = message["content"].as_str().map(String::from);
1073
1074 let mut tool_calls = Vec::new();
1075 if let Some(calls) = message["tool_calls"].as_array() {
1076 for call in calls {
1077 let func = &call["function"];
1078 let name = func["name"].as_str().unwrap_or("").to_string();
1079 let args = func["arguments"].clone();
1080 tool_calls.push(ToolCall {
1081 id: None,
1082 name,
1083 arguments: args,
1084 });
1085 }
1086 }
1087
1088 Ok(ModelCompletion { content, tool_calls })
1089 }
1090}
1091
/// Client for the Mistral chat-completions API.
#[derive(Clone)]
pub struct MistralClient {
    http: reqwest::Client,
    // Model identifier, e.g. "mistral-large-latest".
    model: String,
    api_key: String,
    // API root without a trailing slash; "/chat/completions" is appended.
    base_url: String,
}
1105
1106impl MistralClient {
1107 pub fn new(api_key: impl Into<String>) -> Self {
1108 Self {
1109 http: reqwest::Client::builder()
1110 .timeout(Duration::from_secs(120))
1111 .build()
1112 .expect("failed to build http client"),
1113 model: "mistral-large-latest".to_string(),
1114 api_key: api_key.into(),
1115 base_url: "https://api.mistral.ai/v1".to_string(),
1116 }
1117 }
1118
1119 pub fn with_model(mut self, model: impl Into<String>) -> Self {
1120 self.model = model.into();
1121 self
1122 }
1123
1124 pub fn from_env() -> Result<Self> {
1125 let api_key = std::env::var("MISTRAL_API_KEY")
1126 .map_err(|_| AgnoError::LanguageModel("MISTRAL_API_KEY not set".into()))?;
1127 Ok(Self::new(api_key))
1128 }
1129}
1130
1131#[async_trait]
1132impl LanguageModel for MistralClient {
1133 async fn complete_chat(
1134 &self,
1135 messages: &[Message],
1136 tools: &[ToolDescription],
1137 stream: bool,
1138 ) -> Result<ModelCompletion> {
1139 let mistral_messages: Vec<Value> = messages
1141 .iter()
1142 .map(|m| {
1143 let role = match m.role {
1144 Role::System => "system",
1145 Role::User => "user",
1146 Role::Assistant => "assistant",
1147 Role::Tool => "tool",
1148 };
1149
1150 let mut msg = json!({
1151 "role": role,
1152 "content": m.content.clone()
1153 });
1154
1155 if m.role == Role::Tool {
1157 if let Some(ref tc) = m.tool_call {
1158 if let Some(ref id) = tc.id {
1159 msg["tool_call_id"] = json!(id);
1160 }
1161 }
1162 }
1163
1164 if let Some(ref tc) = m.tool_call {
1166 if m.role == Role::Assistant {
1167 msg["tool_calls"] = json!([{
1168 "id": tc.id.clone().unwrap_or_default(),
1169 "type": "function",
1170 "function": {
1171 "name": tc.name,
1172 "arguments": serialize_tool_arguments(&tc.arguments)
1173 }
1174 }]);
1175 msg["content"] = json!(null);
1176 }
1177 }
1178
1179 msg
1180 })
1181 .collect();
1182
1183 let mut body = json!({
1184 "model": self.model,
1185 "messages": mistral_messages,
1186 "stream": stream
1187 });
1188
1189 if !tools.is_empty() {
1190 let mistral_tools: Vec<Value> = tools
1191 .iter()
1192 .map(|t| {
1193 json!({
1194 "type": "function",
1195 "function": {
1196 "name": t.name,
1197 "description": t.description,
1198 "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1199 }
1200 })
1201 })
1202 .collect();
1203 body["tools"] = json!(mistral_tools);
1204 body["tool_choice"] = json!("auto");
1205 }
1206
1207 let resp = self
1208 .http
1209 .post(format!("{}/chat/completions", self.base_url))
1210 .header("Authorization", format!("Bearer {}", self.api_key))
1211 .header("Content-Type", "application/json")
1212 .json(&body)
1213 .send()
1214 .await
1215 .map_err(|e| AgnoError::LanguageModel(format!("Mistral request failed: {e}")))?;
1216
1217 let status = resp.status();
1218 if !status.is_success() {
1219 let body = resp.text().await.unwrap_or_default();
1220 return Err(coalesce_error(status, &body, "Mistral"));
1221 }
1222
1223 let json: Value = resp
1225 .json()
1226 .await
1227 .map_err(|e| AgnoError::LanguageModel(format!("Mistral parse error: {e}")))?;
1228
1229 let choice = json["choices"]
1230 .as_array()
1231 .and_then(|c| c.first())
1232 .ok_or_else(|| AgnoError::LanguageModel("Mistral returned no choices".into()))?;
1233
1234 let message = &choice["message"];
1235 let content = message["content"].as_str().map(String::from);
1236
1237 let mut tool_calls = Vec::new();
1238 if let Some(calls) = message["tool_calls"].as_array() {
1239 for call in calls {
1240 let id = call["id"].as_str().map(String::from);
1241 let func = &call["function"];
1242 let name = func["name"].as_str().unwrap_or("").to_string();
1243 let args_str = func["arguments"].as_str().unwrap_or("{}");
1244 let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1245 tool_calls.push(ToolCall {
1246 id,
1247 name,
1248 arguments: args,
1249 });
1250 }
1251 }
1252
1253 Ok(ModelCompletion { content, tool_calls })
1254 }
1255}
1256
/// Client for an Azure OpenAI deployment.
#[derive(Clone)]
pub struct AzureOpenAIClient {
    http: reqwest::Client,
    // Resource endpoint, e.g. "https://<resource>.openai.azure.com".
    endpoint: String,
    api_key: String,
    // Deployment (not model) name configured in Azure.
    deployment: String,
    // Value of the "api-version" query parameter.
    api_version: String,
}
1270
1271impl AzureOpenAIClient {
1272 pub fn new(
1273 endpoint: impl Into<String>,
1274 api_key: impl Into<String>,
1275 deployment: impl Into<String>,
1276 ) -> Self {
1277 Self {
1278 http: reqwest::Client::builder()
1279 .timeout(Duration::from_secs(120))
1280 .build()
1281 .expect("failed to build http client"),
1282 endpoint: endpoint.into(),
1283 api_key: api_key.into(),
1284 deployment: deployment.into(),
1285 api_version: "2024-02-01".to_string(),
1286 }
1287 }
1288
1289 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
1290 self.api_version = version.into();
1291 self
1292 }
1293
1294 pub fn from_env() -> Result<Self> {
1295 let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT")
1296 .map_err(|_| AgnoError::LanguageModel("AZURE_OPENAI_ENDPOINT not set".into()))?;
1297 let api_key = std::env::var("AZURE_OPENAI_API_KEY")
1298 .map_err(|_| AgnoError::LanguageModel("AZURE_OPENAI_API_KEY not set".into()))?;
1299 let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT")
1300 .unwrap_or_else(|_| "gpt-4".to_string());
1301 Ok(Self::new(endpoint, api_key, deployment))
1302 }
1303}
1304
1305#[async_trait]
1306impl LanguageModel for AzureOpenAIClient {
1307 async fn complete_chat(
1308 &self,
1309 messages: &[Message],
1310 tools: &[ToolDescription],
1311 stream: bool,
1312 ) -> Result<ModelCompletion> {
1313 let azure_messages: Vec<Value> = messages
1315 .iter()
1316 .map(|m| {
1317 let role = match m.role {
1318 Role::System => "system",
1319 Role::User => "user",
1320 Role::Assistant => "assistant",
1321 Role::Tool => "tool",
1322 };
1323
1324 let mut msg = json!({
1325 "role": role,
1326 "content": m.content.clone()
1327 });
1328
1329 if m.role == Role::Tool {
1330 if let Some(ref tc) = m.tool_call {
1331 if let Some(ref id) = tc.id {
1332 msg["tool_call_id"] = json!(id);
1333 }
1334 }
1335 }
1336
1337 if let Some(ref tc) = m.tool_call {
1338 if m.role == Role::Assistant {
1339 msg["tool_calls"] = json!([{
1340 "id": tc.id.clone().unwrap_or_default(),
1341 "type": "function",
1342 "function": {
1343 "name": tc.name,
1344 "arguments": serialize_tool_arguments(&tc.arguments)
1345 }
1346 }]);
1347 msg["content"] = json!(null);
1348 }
1349 }
1350
1351 msg
1352 })
1353 .collect();
1354
1355 let mut body = json!({
1356 "messages": azure_messages,
1357 "stream": stream
1358 });
1359
1360 if !tools.is_empty() {
1361 let azure_tools: Vec<Value> = tools
1362 .iter()
1363 .map(|t| {
1364 json!({
1365 "type": "function",
1366 "function": {
1367 "name": t.name,
1368 "description": t.description,
1369 "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1370 }
1371 })
1372 })
1373 .collect();
1374 body["tools"] = json!(azure_tools);
1375 body["tool_choice"] = json!("auto");
1376 }
1377
1378 let url = format!(
1379 "{}/openai/deployments/{}/chat/completions?api-version={}",
1380 self.endpoint, self.deployment, self.api_version
1381 );
1382
1383 let resp = self
1384 .http
1385 .post(&url)
1386 .header("api-key", &self.api_key)
1387 .header("Content-Type", "application/json")
1388 .json(&body)
1389 .send()
1390 .await
1391 .map_err(|e| AgnoError::LanguageModel(format!("Azure OpenAI request failed: {e}")))?;
1392
1393 let status = resp.status();
1394 if !status.is_success() {
1395 let body = resp.text().await.unwrap_or_default();
1396 return Err(coalesce_error(status, &body, "Azure OpenAI"));
1397 }
1398
1399 let json: Value = resp
1400 .json()
1401 .await
1402 .map_err(|e| AgnoError::LanguageModel(format!("Azure OpenAI parse error: {e}")))?;
1403
1404 let choice = json["choices"]
1405 .as_array()
1406 .and_then(|c| c.first())
1407 .ok_or_else(|| AgnoError::LanguageModel("Azure OpenAI returned no choices".into()))?;
1408
1409 let message = &choice["message"];
1410 let content = message["content"].as_str().map(String::from);
1411
1412 let mut tool_calls = Vec::new();
1413 if let Some(calls) = message["tool_calls"].as_array() {
1414 for call in calls {
1415 let id = call["id"].as_str().map(String::from);
1416 let func = &call["function"];
1417 let name = func["name"].as_str().unwrap_or("").to_string();
1418 let args_str = func["arguments"].as_str().unwrap_or("{}");
1419 let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1420 tool_calls.push(ToolCall {
1421 id,
1422 name,
1423 arguments: args,
1424 });
1425 }
1426 }
1427
1428 Ok(ModelCompletion { content, tool_calls })
1429 }
1430}
1431
#[derive(Clone)]
/// Chat-completion client for Together AI's OpenAI-compatible API.
pub struct TogetherClient {
    /// Shared HTTP client (120 s request timeout configured in `new`).
    http: reqwest::Client,
    /// Model identifier, e.g. `meta-llama/Llama-3.3-70B-Instruct-Turbo`.
    model: String,
    /// Bearer token for `api.together.xyz`.
    api_key: String,
}
1444
1445impl TogetherClient {
1446 pub fn new(api_key: impl Into<String>) -> Self {
1447 Self {
1448 http: reqwest::Client::builder()
1449 .timeout(Duration::from_secs(120))
1450 .build()
1451 .expect("failed to build http client"),
1452 model: "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(),
1453 api_key: api_key.into(),
1454 }
1455 }
1456
1457 pub fn with_model(mut self, model: impl Into<String>) -> Self {
1458 self.model = model.into();
1459 self
1460 }
1461
1462 pub fn from_env() -> Result<Self> {
1463 let api_key = std::env::var("TOGETHER_API_KEY")
1464 .map_err(|_| AgnoError::LanguageModel("TOGETHER_API_KEY not set".into()))?;
1465 Ok(Self::new(api_key))
1466 }
1467}
1468
1469#[async_trait]
1470impl LanguageModel for TogetherClient {
1471 async fn complete_chat(
1472 &self,
1473 messages: &[Message],
1474 tools: &[ToolDescription],
1475 stream: bool,
1476 ) -> Result<ModelCompletion> {
1477 let together_messages: Vec<Value> = messages
1478 .iter()
1479 .map(|m| {
1480 let role = match m.role {
1481 Role::System => "system",
1482 Role::User => "user",
1483 Role::Assistant => "assistant",
1484 Role::Tool => "tool",
1485 };
1486 json!({
1487 "role": role,
1488 "content": m.content.clone()
1489 })
1490 })
1491 .collect();
1492
1493 let mut body = json!({
1494 "model": self.model,
1495 "messages": together_messages,
1496 "stream": stream
1497 });
1498
1499 if !tools.is_empty() {
1500 let together_tools: Vec<Value> = tools
1501 .iter()
1502 .map(|t| {
1503 json!({
1504 "type": "function",
1505 "function": {
1506 "name": t.name,
1507 "description": t.description,
1508 "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1509 }
1510 })
1511 })
1512 .collect();
1513 body["tools"] = json!(together_tools);
1514 }
1515
1516 let resp = self
1517 .http
1518 .post("https://api.together.xyz/v1/chat/completions")
1519 .header("Authorization", format!("Bearer {}", self.api_key))
1520 .header("Content-Type", "application/json")
1521 .json(&body)
1522 .send()
1523 .await
1524 .map_err(|e| AgnoError::LanguageModel(format!("Together request failed: {e}")))?;
1525
1526 let status = resp.status();
1527 if !status.is_success() {
1528 let body = resp.text().await.unwrap_or_default();
1529 return Err(coalesce_error(status, &body, "Together"));
1530 }
1531
1532 let json: Value = resp
1533 .json()
1534 .await
1535 .map_err(|e| AgnoError::LanguageModel(format!("Together parse error: {e}")))?;
1536
1537 let choice = json["choices"]
1538 .as_array()
1539 .and_then(|c| c.first())
1540 .ok_or_else(|| AgnoError::LanguageModel("Together returned no choices".into()))?;
1541
1542 let message = &choice["message"];
1543 let content = message["content"].as_str().map(String::from);
1544
1545 let mut tool_calls = Vec::new();
1546 if let Some(calls) = message["tool_calls"].as_array() {
1547 for call in calls {
1548 let id = call["id"].as_str().map(String::from);
1549 let func = &call["function"];
1550 let name = func["name"].as_str().unwrap_or("").to_string();
1551 let args_str = func["arguments"].as_str().unwrap_or("{}");
1552 let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1553 tool_calls.push(ToolCall {
1554 id,
1555 name,
1556 arguments: args,
1557 });
1558 }
1559 }
1560
1561 Ok(ModelCompletion { content, tool_calls })
1562 }
1563}
1564
#[derive(Clone)]
/// Chat-completion client for Fireworks AI's OpenAI-compatible API.
pub struct FireworksClient {
    /// Shared HTTP client (120 s request timeout configured in `new`).
    http: reqwest::Client,
    /// Model identifier, e.g. `accounts/fireworks/models/llama-v3p1-70b-instruct`.
    model: String,
    /// Bearer token for `api.fireworks.ai`.
    api_key: String,
}
1577
1578impl FireworksClient {
1579 pub fn new(api_key: impl Into<String>) -> Self {
1580 Self {
1581 http: reqwest::Client::builder()
1582 .timeout(Duration::from_secs(120))
1583 .build()
1584 .expect("failed to build http client"),
1585 model: "accounts/fireworks/models/llama-v3p1-70b-instruct".to_string(),
1586 api_key: api_key.into(),
1587 }
1588 }
1589
1590 pub fn with_model(mut self, model: impl Into<String>) -> Self {
1591 self.model = model.into();
1592 self
1593 }
1594
1595 pub fn from_env() -> Result<Self> {
1596 let api_key = std::env::var("FIREWORKS_API_KEY")
1597 .map_err(|_| AgnoError::LanguageModel("FIREWORKS_API_KEY not set".into()))?;
1598 Ok(Self::new(api_key))
1599 }
1600}
1601
1602#[async_trait]
1603impl LanguageModel for FireworksClient {
1604 async fn complete_chat(
1605 &self,
1606 messages: &[Message],
1607 tools: &[ToolDescription],
1608 stream: bool,
1609 ) -> Result<ModelCompletion> {
1610 let fireworks_messages: Vec<Value> = messages
1611 .iter()
1612 .map(|m| {
1613 let role = match m.role {
1614 Role::System => "system",
1615 Role::User => "user",
1616 Role::Assistant => "assistant",
1617 Role::Tool => "tool",
1618 };
1619 json!({
1620 "role": role,
1621 "content": m.content.clone()
1622 })
1623 })
1624 .collect();
1625
1626 let mut body = json!({
1627 "model": self.model,
1628 "messages": fireworks_messages,
1629 "stream": stream
1630 });
1631
1632 if !tools.is_empty() {
1633 let fireworks_tools: Vec<Value> = tools
1634 .iter()
1635 .map(|t| {
1636 json!({
1637 "type": "function",
1638 "function": {
1639 "name": t.name,
1640 "description": t.description,
1641 "parameters": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1642 }
1643 })
1644 })
1645 .collect();
1646 body["tools"] = json!(fireworks_tools);
1647 }
1648
1649 let resp = self
1650 .http
1651 .post("https://api.fireworks.ai/inference/v1/chat/completions")
1652 .header("Authorization", format!("Bearer {}", self.api_key))
1653 .header("Content-Type", "application/json")
1654 .json(&body)
1655 .send()
1656 .await
1657 .map_err(|e| AgnoError::LanguageModel(format!("Fireworks request failed: {e}")))?;
1658
1659 let status = resp.status();
1660 if !status.is_success() {
1661 let body = resp.text().await.unwrap_or_default();
1662 return Err(coalesce_error(status, &body, "Fireworks"));
1663 }
1664
1665 let json: Value = resp
1666 .json()
1667 .await
1668 .map_err(|e| AgnoError::LanguageModel(format!("Fireworks parse error: {e}")))?;
1669
1670 let choice = json["choices"]
1671 .as_array()
1672 .and_then(|c| c.first())
1673 .ok_or_else(|| AgnoError::LanguageModel("Fireworks returned no choices".into()))?;
1674
1675 let message = &choice["message"];
1676 let content = message["content"].as_str().map(String::from);
1677
1678 let mut tool_calls = Vec::new();
1679 if let Some(calls) = message["tool_calls"].as_array() {
1680 for call in calls {
1681 let id = call["id"].as_str().map(String::from);
1682 let func = &call["function"];
1683 let name = func["name"].as_str().unwrap_or("").to_string();
1684 let args_str = func["arguments"].as_str().unwrap_or("{}");
1685 let args: Value = serde_json::from_str(args_str).unwrap_or(json!({}));
1686 tool_calls.push(ToolCall {
1687 id,
1688 name,
1689 arguments: args,
1690 });
1691 }
1692 }
1693
1694 Ok(ModelCompletion { content, tool_calls })
1695 }
1696}
1697
#[derive(Clone)]
/// Client for Anthropic Claude models served through AWS Bedrock.
pub struct AwsBedrockClient {
    /// Shared Bedrock runtime SDK client; `Arc` keeps `Clone` cheap.
    client: std::sync::Arc<aws_sdk_bedrockruntime::Client>,
    /// Bedrock model id, e.g. `anthropic.claude-3-sonnet-20240229-v1:0`.
    model_id: String,
}
1709
1710impl AwsBedrockClient {
1711 pub async fn new(region: Option<String>) -> Self {
1712 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
1713 if let Some(r) = region {
1714 loader = loader.region(aws_config::Region::new(r));
1715 }
1716 let sdk_config = loader.load().await;
1717 let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
1718
1719 Self {
1720 client: std::sync::Arc::new(client),
1721 model_id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
1722 }
1723 }
1724
1725 pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
1726 self.model_id = model_id.into();
1727 self
1728 }
1729}
1730
1731#[async_trait]
1732impl LanguageModel for AwsBedrockClient {
1733 async fn complete_chat(
1734 &self,
1735 messages: &[Message],
1736 tools: &[ToolDescription],
1737 _stream: bool, ) -> Result<ModelCompletion> {
1739 let system_prompt = messages
1741 .iter()
1742 .filter(|m| m.role == Role::System)
1743 .map(|m| m.content.clone())
1744 .collect::<Vec<_>>()
1745 .join("\n");
1746
1747 let mut bedrock_messages = Vec::new();
1748 for m in messages {
1749 if m.role == Role::System { continue; }
1750
1751 let role = match m.role {
1752 Role::User => "user",
1753 Role::Assistant => "assistant",
1754 Role::Tool => "user", _ => "user",
1756 };
1757
1758 let content = if m.role == Role::Tool {
1759 json!([{
1761 "type": "tool_result",
1762 "tool_use_id": m.tool_call.as_ref().and_then(|t| t.id.clone()).unwrap_or_default(),
1763 "content": m.content
1764 }])
1765 } else if let Some(ref tc) = m.tool_call {
1766 json!([{
1768 "type": "tool_use",
1769 "id": tc.id.clone().unwrap_or_default(),
1770 "name": tc.name,
1771 "input": tc.arguments
1772 }])
1773 } else {
1774 json!(m.content)
1775 };
1776
1777 bedrock_messages.push(json!({
1778 "role": role,
1779 "content": content
1780 }));
1781 }
1782
1783 let mut body = json!({
1784 "anthropic_version": "bedrock-2023-05-31",
1785 "max_tokens": 4096,
1786 "messages": bedrock_messages
1787 });
1788
1789 if !system_prompt.is_empty() {
1790 body["system"] = json!(system_prompt);
1791 }
1792
1793 if !tools.is_empty() {
1794 let tool_defs: Vec<Value> = tools.iter().map(|t| {
1795 json!({
1796 "name": t.name,
1797 "description": t.description,
1798 "input_schema": t.parameters.clone().unwrap_or(json!({"type": "object", "properties": {}}))
1799 })
1800 }).collect();
1801 body["tools"] = json!(tool_defs);
1802 }
1803
1804 let blob = aws_sdk_bedrockruntime::primitives::Blob::new(serde_json::to_vec(&body).unwrap());
1805
1806 let output = self.client
1807 .invoke_model()
1808 .model_id(&self.model_id)
1809 .body(blob)
1810 .send()
1811 .await
1812 .map_err(|e| AgnoError::LanguageModel(format!("Bedrock invocation failed: {}", e)))?;
1813
1814 let response_body: Value = serde_json::from_slice(output.body.as_ref())
1815 .map_err(|e| AgnoError::LanguageModel(format!("Failed to parse Bedrock response: {}", e)))?;
1816
1817 let mut content = None;
1818 let mut tool_calls = Vec::new();
1819
1820 if let Some(content_blocks) = response_body["content"].as_array() {
1821 let mut text_parts = Vec::new();
1822 for block in content_blocks {
1823 if block["type"] == "text" {
1824 if let Some(text) = block["text"].as_str() {
1825 text_parts.push(text);
1826 }
1827 } else if block["type"] == "tool_use" {
1828 let id = block["id"].as_str().map(String::from);
1829 let name = block["name"].as_str().unwrap_or_default().to_string();
1830 let args = block["input"].clone();
1831 tool_calls.push(ToolCall { id, name, arguments: args });
1832 }
1833 }
1834 if !text_parts.is_empty() {
1835 content = Some(text_parts.join("\n"));
1836 }
1837 }
1838
1839 Ok(ModelCompletion { content, tool_calls })
1840 }
1841}
1842
1843
/// Deterministic test double for [`LanguageModel`] that replays a
/// scripted queue of responses, one per `complete_chat` call.
pub struct StubModel {
    // FIFO of raw scripted responses; Mutex allows mutation behind &self.
    responses: Mutex<VecDeque<String>>,
}
1847
1848
1849impl StubModel {
1850 pub fn new(responses: Vec<String>) -> Arc<Self> {
1851 Arc::new(Self {
1852 responses: Mutex::new(responses.into()),
1853 })
1854 }
1855}
1856
/// Scripted instruction for [`StubModel`], decoded from JSON via the
/// `action` tag, e.g. `{"action":"respond","content":"hi"}` or
/// `{"action":"call_tool","name":"t","arguments":{}}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "action", rename_all = "snake_case")]
enum StubDirective {
    /// Produce plain assistant content.
    Respond { content: String },
    /// Produce a single tool invocation.
    CallTool { name: String, arguments: Value },
}
1863
1864#[async_trait]
1865impl LanguageModel for StubModel {
1866 async fn complete_chat(
1867 &self,
1868 _messages: &[Message],
1869 _tools: &[ToolDescription],
1870 _stream: bool,
1871 ) -> Result<ModelCompletion> {
1872 let mut locked = self.responses.lock().expect("stub model poisoned");
1873 let raw = locked.pop_front().ok_or_else(|| {
1874 AgnoError::LanguageModel("StubModel ran out of scripted responses".into())
1875 })?;
1876
1877 match serde_json::from_str::<StubDirective>(&raw) {
1878 Ok(StubDirective::Respond { content }) => Ok(ModelCompletion {
1879 content: Some(content),
1880 tool_calls: Vec::new(),
1881 }),
1882 Ok(StubDirective::CallTool { name, arguments }) => Ok(ModelCompletion {
1883 content: None,
1884 tool_calls: vec![ToolCall {
1885 id: None,
1886 name,
1887 arguments,
1888 }],
1889 }),
1890 Err(_) => Ok(ModelCompletion {
1891 content: Some(raw),
1892 tool_calls: Vec::new(),
1893 }),
1894 }
1895 }
1896}
1897
/// OpenAI-format chat message as serialized on the wire.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    /// Omitted entirely (not serialized as null) when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Set on tool-role messages to echo the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Set on assistant messages that invoke tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}

/// One tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    /// Call kind discriminator (the clients in this file use "function").
    r#type: String,
    function: OpenAiFunctionCall,
}

/// Function name plus its arguments.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    /// JSON object encoded as a string, per the OpenAI schema.
    arguments: String,
}

/// Tool advertisement placed in a request body.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiTool {
    r#type: String,
    function: OpenAiFunction,
}

/// Declares a callable function: name, optional description, and an
/// optional JSON-schema `parameters` object.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunction {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Top-level non-streaming completion response.
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    choices: Vec<OpenAiChoice>,
}

/// A single completion candidate.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiChoiceMessage,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Assistant message inside a choice: text content and/or tool calls.
#[derive(Debug, Deserialize)]
struct OpenAiChoiceMessage {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}
1956
/// Accumulator used while reassembling a tool call from incremental
/// streaming deltas.
#[derive(Default)]
struct OpenAiToolCallState {
    id: Option<String>,
    name: Option<String>,
    /// Argument fragments concatenated in arrival order.
    arguments: String,
}

/// One chunk of a streamed completion.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChunk {
    choices: Vec<OpenAiDeltaChoice>,
}

/// Streamed counterpart of `OpenAiChoice`.
#[derive(Debug, Deserialize)]
struct OpenAiDeltaChoice {
    delta: OpenAiDelta,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Incremental message fragment within a stream chunk.
#[derive(Debug, Deserialize)]
struct OpenAiDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

/// Partial tool-call update; all fields may be absent in a given chunk.
#[derive(Debug, Deserialize)]
struct OpenAiToolCallDelta {
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

/// Partial function payload: a name and/or an arguments fragment.
#[derive(Debug, Deserialize)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
1997
/// Anthropic-format chat message: a role plus typed content blocks.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: Vec<AnthropicContentBlock>,
}

/// One content block; which optional fields are populated depends on
/// the `type` discriminator.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicContentBlock {
    r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    input_schema: Option<Value>,
}

/// Tool declaration in an Anthropic request.
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicTool {
    name: String,
    description: String,
    /// JSON schema describing the tool's expected input.
    input_schema: Value,
}

/// Non-streaming Anthropic response: a list of content blocks.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContentBlock>,
}

/// One chunk of a streamed Anthropic response.
#[derive(Debug, Deserialize)]
struct AnthropicStreamChunk {
    delta: AnthropicDelta,
}

/// Incremental text fragment within a stream chunk.
#[derive(Debug, Deserialize)]
struct AnthropicDelta {
    #[serde(default)]
    text: Option<String>,
}
2037
/// Gemini-format message: a role with one or more text parts.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiMessage {
    role: String,
    parts: Vec<GeminiPart>,
}

/// A single text fragment of a Gemini message.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
    text: String,
}

/// Top-level Gemini generation response.
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
}

/// One generated candidate.
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
    content: GeminiCandidateContent,
}

/// Candidate payload: the generated text parts.
#[derive(Debug, Deserialize)]
struct GeminiCandidateContent {
    parts: Vec<GeminiPart>,
}
2063
/// Cohere chat message: plain role plus string content.
#[derive(Debug, Serialize, Deserialize)]
struct CohereMessage {
    role: String,
    content: String,
}

/// Tool advertisement in a Cohere request (OpenAI-style shape).
#[derive(Debug, Serialize, Deserialize)]
struct CohereTool {
    r#type: String,
    function: CohereFunction,
}

/// Declares a callable function for Cohere tool use.
#[derive(Debug, Serialize, Deserialize)]
struct CohereFunction {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Non-streaming Cohere chat response; both fields may be absent.
#[derive(Debug, Deserialize)]
struct CohereResponse {
    #[serde(default)]
    message: Option<CohereResponseMessage>,
    #[serde(default)]
    tool_calls: Option<Vec<CohereToolCall>>,
}

/// Assistant message payload; `content` is kept as a raw `Value`
/// (its shape is not fixed by this module).
#[derive(Debug, Deserialize)]
struct CohereResponseMessage {
    #[serde(default)]
    content: Option<Value>,
}

/// A tool invocation returned by Cohere.
#[derive(Debug, Deserialize)]
struct CohereToolCall {
    #[serde(default)]
    id: Option<String>,
    function: CohereFunctionCall,
}

/// Function name plus JSON-string-encoded arguments.
#[derive(Debug, Deserialize)]
struct CohereFunctionCall {
    name: String,
    arguments: String,
}

/// One chunk of a streamed Cohere response.
#[derive(Debug, Deserialize)]
struct CohereStreamChunk {
    #[serde(default)]
    delta: Option<CohereDelta>,
}

/// Incremental update: an optional partial message.
#[derive(Debug, Deserialize)]
struct CohereDelta {
    #[serde(default)]
    message: Option<CohereResponseMessage>,
}
2124