1use futures::stream::BoxStream;
28use serde::{Deserialize, Serialize};
29
30use crate::error::ApiError;
31use crate::msg::LlmEvent;
32use crate::raw::shared::ToolDefinition;
33use crate::types::CompleteResponse;
34
/// An image content part: the image payload plus its MIME type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageContent {
    /// The image payload — inline base64 bytes or a URL reference.
    pub data: ImageData,
    /// MIME type of the image (e.g. "image/png").
    pub mime_type: String,
}
45
/// How the bytes of an image are supplied.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ImageData {
    /// Base64-encoded image bytes carried inline.
    Base64(String),
    /// A URL pointing at the image.
    Url(String),
}
55
/// A single content part of a message; serialized externally tagged by `type`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// A plain-text part.
    Text { text: String },
    /// An image part.
    Image(ImageContent),
}
63
64impl Content {
65 pub fn text(s: impl Into<String>) -> Self {
66 Content::Text { text: s.into() }
67 }
68}
69
70impl From<&str> for Content {
71 fn from(s: &str) -> Self {
72 Content::Text {
73 text: s.to_string(),
74 }
75 }
76}
77impl From<String> for Content {
78 fn from(s: String) -> Self {
79 Content::Text { text: s }
80 }
81}
82
83pub type UserContent = Content;
85
/// A single turn in the conversation history.
#[derive(Debug, Clone)]
pub enum Message {
    /// A user turn, as a list of content parts (text and/or images).
    User(Vec<UserContent>),

    /// An assistant turn: optional text, optional reasoning trace, and any
    /// tool calls the model requested.
    Assistant {
        content: Option<String>,
        reasoning: Option<String>,
        tool_calls: Vec<ToolCall>,
    },

    /// The result of executing one tool call, keyed back to the originating
    /// `ToolCall::id`.
    ToolResult {
        call_id: String,
        content: Vec<Content>,
    },
}
108
109impl Message {
110 pub fn estimate_tokens(&self) -> usize {
115 use std::sync::OnceLock;
116 static BPE: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
117 let bpe = BPE.get_or_init(|| tiktoken_rs::cl100k_base().unwrap());
118 let mut tokens = 0;
119
120 match self {
121 Message::User(parts) => {
122 tokens += 4; for part in parts {
124 match part {
125 UserContent::Text { text: t } => {
126 tokens += bpe.encode_with_special_tokens(t).len()
127 }
128 UserContent::Image(_) => tokens += 1000, }
130 }
131 }
132 Message::Assistant {
133 content,
134 reasoning,
135 tool_calls,
136 } => {
137 tokens += 4;
138 if let Some(c) = content {
139 tokens += bpe.encode_with_special_tokens(c).len();
140 }
141 if let Some(r) = reasoning {
142 tokens += bpe.encode_with_special_tokens(r).len();
143 }
144 for tc in tool_calls {
145 tokens += bpe.encode_with_special_tokens(&tc.name).len();
146 tokens += bpe.encode_with_special_tokens(&tc.arguments).len();
147 }
148 }
149 Message::ToolResult { content, .. } => {
150 tokens += 4;
151 for part in content {
152 match part {
153 Content::Text { text } => {
154 tokens += bpe.encode_with_special_tokens(text).len()
155 }
156 Content::Image(_) => tokens += 1000,
157 }
158 }
159 }
160 }
161 tokens
162 }
163}
164
/// Trims the front of `history` in place so the total estimated token count
/// fits within `budget`, while avoiding stranding `ToolResult` messages whose
/// originating `Assistant` tool calls were dropped.
///
/// Always keeps at least the most recent message, even if it alone exceeds
/// the budget.
pub fn truncate_to_token_budget(history: &mut Vec<Message>, budget: usize) {
    let mut acc: usize = 0;
    // Walk newest-to-oldest accumulating estimated tokens; the first index
    // that busts the budget fixes the cut point. `keep_from` is the index of
    // the first message we keep (== len means nothing needs dropping).
    let mut keep_from = history.len(); for (i, msg) in history.iter().enumerate().rev() {
        acc += msg.estimate_tokens();
        if acc > budget {
            // Keep everything after i; the .min guarantees the newest message
            // survives. (The loop body only runs when history is non-empty,
            // so `len - 1` cannot underflow.)
            keep_from = (i + 1).min(history.len() - 1);
            break;
        }
    }

    if keep_from == history.len() {
        // Everything fits; nothing to drop.
        return;
    }

    // Skip leading ToolResults: the Assistant carrying their matching tool
    // calls was just truncated away, so keeping them would orphan them.
    while keep_from < history.len() {
        match &history[keep_from] {
            Message::ToolResult { .. } => keep_from += 1,
            _ => break,
        }
    }

    // If the first surviving message is an Assistant that issued tool calls,
    // drop it together with its matching ToolResults so the kept history does
    // not begin mid tool-call exchange.
    if keep_from < history.len()
        && let Message::Assistant { tool_calls, .. } = &history[keep_from]
        && !tool_calls.is_empty() {
        let ids: std::collections::HashSet<&str> =
            tool_calls.iter().map(|tc| tc.id.as_str()).collect();
        keep_from += 1;
        while keep_from < history.len() {
            match &history[keep_from] {
                Message::ToolResult { call_id, .. } if ids.contains(call_id.as_str()) => {
                    keep_from += 1;
                }
                _ => break,
            }
        }
    }

    // Never empty the history entirely.
    // NOTE(review): this fallback can land on a ToolResult whose Assistant was
    // dropped (e.g. history ending [Assistant{tool_calls}, ToolResult] where
    // both bust the budget), leaving one orphaned ToolResult — confirm whether
    // downstream providers tolerate that.
    if keep_from >= history.len() {
        keep_from = history.len().saturating_sub(1);
    }

    if keep_from > 0 {
        history.drain(0..keep_from);
    }
}
226
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id, echoed back via `Message::ToolResult::call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw arguments string as produced by the model (typically JSON).
    pub arguments: String,
}
237
/// The LLM API providers this crate can talk to.
///
/// Routing (see `Request::stream`/`Request::complete`): DeepSeek, OpenAI,
/// Kimi, GLM and Grok use the OpenAI-compatible backend; Anthropic and
/// Minimax use the Anthropic backend; Gemini and OpenRouter have their own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    #[serde(rename = "deepseek")]
    DeepSeek,
    #[serde(rename = "openai")]
    OpenAI,
    #[serde(rename = "anthropic")]
    Anthropic,
    #[serde(rename = "gemini")]
    Gemini,
    #[serde(rename = "kimi")]
    Kimi,
    #[serde(rename = "glm")]
    Glm,
    #[serde(rename = "minimax")]
    Minimax,
    #[serde(rename = "grok")]
    Grok,
    #[serde(rename = "openrouter")]
    OpenRouter,
}
270
271impl Provider {
272 pub fn default_base_url(&self) -> &'static str {
274 match self {
275 Provider::DeepSeek => "https://api.deepseek.com",
276 Provider::OpenAI => "https://api.openai.com/v1",
277 Provider::Anthropic => "https://api.anthropic.com",
278 Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta",
279 Provider::Kimi => "https://api.moonshot.cn/v1",
280 Provider::Glm => "https://open.bigmodel.cn/api/paas/v4",
281 Provider::Minimax => "https://api.minimaxi.com/anthropic",
282 Provider::Grok => "https://api.x.ai/v1",
283 Provider::OpenRouter => "https://openrouter.ai/api/v1",
284 }
285 }
286
287 pub fn default_model(&self) -> &'static str {
289 match self {
290 Provider::DeepSeek => "deepseek-chat",
291 Provider::OpenAI => "gpt-4o",
292 Provider::Anthropic => "claude-sonnet-4-20250514",
293 Provider::Gemini => "gemini-2.0-flash",
294 Provider::Kimi => "kimi-k2.5",
295 Provider::Glm => "glm-5",
296 Provider::Minimax => "MiniMax-M2.7",
297 Provider::Grok => "grok-4",
298 Provider::OpenRouter => "openrouter/auto",
299 }
300 }
301}
302
/// Controls whether — and which — tool the model may call.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// The model decides whether to call a tool (the default).
    #[default]
    Auto,
    /// No tool calls.
    None,
    /// The model must call some tool.
    Required,
    /// The model must call the named tool.
    Tool(String),
}
319
/// A fully-specified LLM request: provider/credentials, conversation state,
/// tool configuration, sampling options and retry policy. Built fluently via
/// `Request::new` (or a provider shortcut) plus the builder methods.
#[derive(Debug, Clone)]
pub struct Request {
    /// Which backend to dispatch to.
    pub provider: Provider,
    /// API key sent to the provider.
    pub api_key: String,
    /// Endpoint base URL; empty means "use the provider default"
    /// (see `effective_base_url`).
    pub base_url: String,

    /// Model identifier.
    pub model: String,
    /// Optional system prompt.
    pub system_message: Option<String>,
    /// Conversation history.
    pub messages: Vec<Message>,

    /// Tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Tool-choice constraint, if any.
    pub tool_choice: Option<ToolChoice>,

    /// Sampling temperature, if set.
    pub temperature: Option<f32>,
    /// Cap on generated tokens, if set.
    pub max_tokens: Option<u32>,
    /// Requested output format, if set.
    pub response_format: Option<ResponseFormat>,
    /// Extra top-level fields forwarded to the backend via `AgentConfig`.
    pub extra_body: serde_json::Map<String, serde_json::Value>,

    /// Maximum retry attempts.
    pub max_retries: u32,
    /// Initial delay before the first retry, in milliseconds.
    pub retry_delay_ms: u64,
}
370
371impl Request {
372 pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
377 Self {
378 base_url: provider.default_base_url().to_string(),
379 model: provider.default_model().to_string(),
380 api_key: api_key.into(),
381 provider,
382 system_message: None,
383 messages: Vec::new(),
384 tools: Vec::new(),
385 tool_choice: None,
386 temperature: None,
387 max_tokens: None,
388 response_format: None,
389 extra_body: serde_json::Map::new(),
390 max_retries: 3,
391 retry_delay_ms: 1000,
392 }
393 }
394
    /// Shorthand for `Request::new(Provider::DeepSeek, api_key)`.
    pub fn deepseek(api_key: impl Into<String>) -> Self {
        Self::new(Provider::DeepSeek, api_key)
    }

    /// Shorthand for `Request::new(Provider::OpenAI, api_key)`.
    pub fn openai(api_key: impl Into<String>) -> Self {
        Self::new(Provider::OpenAI, api_key)
    }

    /// Shorthand for `Request::new(Provider::Anthropic, api_key)`.
    pub fn anthropic(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Anthropic, api_key)
    }

    /// Shorthand for `Request::new(Provider::Gemini, api_key)`.
    pub fn gemini(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Gemini, api_key)
    }

    /// Shorthand for `Request::new(Provider::Kimi, api_key)`.
    pub fn kimi(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Kimi, api_key)
    }

    /// Shorthand for `Request::new(Provider::Glm, api_key)`.
    pub fn glm(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Glm, api_key)
    }

    /// Shorthand for `Request::new(Provider::Minimax, api_key)`.
    pub fn minimax(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Minimax, api_key)
    }

    /// Shorthand for `Request::new(Provider::Grok, api_key)`.
    pub fn grok(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Grok, api_key)
    }

    /// Shorthand for `Request::new(Provider::OpenRouter, api_key)`.
    pub fn openrouter(api_key: impl Into<String>) -> Self {
        Self::new(Provider::OpenRouter, api_key)
    }
439
    /// Overrides the base URL (otherwise the provider default is used).
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the model identifier.
    pub fn model(mut self, m: impl Into<String>) -> Self {
        self.model = m.into();
        self
    }

    /// Sets the system prompt.
    pub fn system_prompt(mut self, p: impl Into<String>) -> Self {
        self.system_message = Some(p.into());
        self
    }

    /// Appends a single message to the conversation.
    pub fn message(mut self, m: Message) -> Self {
        self.messages.push(m);
        self
    }

    /// Appends a plain-text user message.
    pub fn user(self, text: impl Into<String>) -> Self {
        self.message(Message::User(vec![Content::text(text)]))
    }

    /// Replaces the entire conversation history.
    pub fn messages(mut self, msgs: Vec<Message>) -> Self {
        self.messages = msgs;
        self
    }

    /// Replaces the set of tools offered to the model.
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = tools;
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }

    /// Caps the number of tokens the model may generate.
    pub fn max_tokens(mut self, n: u32) -> Self {
        self.max_tokens = Some(n);
        self
    }

    /// Requests free-form text output.
    pub fn text(mut self) -> Self {
        self.response_format = Some(ResponseFormat::Text);
        self
    }

    /// Requests structured output conforming to the named `schema`; `strict`
    /// asks the provider for exact schema adherence.
    pub fn json_schema(
        mut self,
        name: impl Into<String>,
        schema: serde_json::Value,
        strict: bool,
    ) -> Self {
        self.response_format = Some(ResponseFormat::JsonSchema {
            name: name.into(),
            schema,
            strict,
        });
        self
    }

    /// Requests JSON-object output (valid JSON, no specific schema).
    pub fn json(mut self) -> Self {
        self.response_format = Some(ResponseFormat::JsonObject);
        self
    }

    /// Configures retry behavior: up to `max` retries, starting after
    /// `initial_delay_ms` milliseconds.
    pub fn retries(mut self, max: u32, initial_delay_ms: u64) -> Self {
        self.max_retries = max;
        self.retry_delay_ms = initial_delay_ms;
        self
    }

    /// Sets arbitrary extra fields carried to the backend via `AgentConfig`.
    pub fn extra_body(mut self, extra: serde_json::Map<String, serde_json::Value>) -> Self {
        self.extra_body = extra;
        self
    }
544
545 pub fn effective_base_url(&self) -> &str {
549 if self.base_url.is_empty() {
550 self.provider.default_base_url()
551 } else {
552 &self.base_url
553 }
554 }
555
    /// Streams the completion as a boxed event stream, dispatching to the raw
    /// backend matching `self.provider`.
    ///
    /// # Errors
    /// Returns `ApiError` when the underlying backend call fails.
    pub async fn stream(
        &self,
        http: &reqwest::Client,
    ) -> Result<BoxStream<'static, LlmEvent>, ApiError> {
        let config = self.to_agent_config();
        let messages = &self.messages;
        let tools = &self.tools;

        match self.provider {
            // DeepSeek speaks the OpenAI-compatible protocol but needs its
            // history preprocessed and cannot do json_schema output.
            Provider::DeepSeek => {
                use crate::raw::deepseek::prepare_history;
                use crate::raw::openai::stream_openai_compatible;
                let config = degrade_json_schema_for_deepseek(config);
                stream_openai_compatible(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                    Some(prepare_history),
                )
                .await
            }
            Provider::OpenAI => {
                use crate::raw::openai::stream_openai_compatible;
                stream_openai_compatible(&self.api_key, http, &config, messages, tools, None).await
            }
            Provider::Anthropic => {
                crate::raw::anthropic::stream_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Gemini => {
                crate::raw::gemini::stream_gemini(&self.api_key, http, &config, messages, tools)
                    .await
            }
            // Minimax exposes an Anthropic-compatible endpoint.
            Provider::Minimax => {
                crate::raw::anthropic::stream_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            // These providers are OpenAI-compatible with no history preprocessing.
            Provider::Kimi | Provider::Glm | Provider::Grok => {
                use crate::raw::openai::stream_openai_compatible;
                stream_openai_compatible(&self.api_key, http, &config, messages, tools, None).await
            }
            Provider::OpenRouter => {
                crate::raw::openrouter::stream_openrouter(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
        }
    }
626
    /// Runs the request to completion (non-streaming), dispatching to the raw
    /// backend matching `self.provider`. Mirrors the routing in `stream`.
    ///
    /// # Errors
    /// Returns `ApiError` when the underlying backend call fails.
    pub async fn complete(&self, http: &reqwest::Client) -> Result<CompleteResponse, ApiError> {
        let config = self.to_agent_config();
        let messages = &self.messages;
        let tools = &self.tools;

        match self.provider {
            // DeepSeek speaks the OpenAI-compatible protocol but needs its
            // history preprocessed and cannot do json_schema output.
            Provider::DeepSeek => {
                use crate::raw::deepseek::prepare_history;
                use crate::raw::openai::complete_openai_compatible;
                let config = degrade_json_schema_for_deepseek(config);
                complete_openai_compatible(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                    Some(prepare_history),
                )
                .await
            }
            Provider::OpenAI => {
                use crate::raw::openai::complete_openai_compatible;
                complete_openai_compatible(&self.api_key, http, &config, messages, tools, None)
                    .await
            }
            Provider::Anthropic => {
                crate::raw::anthropic::complete_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Gemini => {
                crate::raw::gemini::complete_gemini(&self.api_key, http, &config, messages, tools)
                    .await
            }
            // Minimax exposes an Anthropic-compatible endpoint.
            Provider::Minimax => {
                crate::raw::anthropic::complete_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            // These providers are OpenAI-compatible with no history preprocessing.
            Provider::Kimi | Provider::Glm | Provider::Grok => {
                use crate::raw::openai::complete_openai_compatible;
                complete_openai_compatible(&self.api_key, http, &config, messages, tools, None)
                    .await
            }
            Provider::OpenRouter => {
                crate::raw::openrouter::complete_openrouter(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
        }
    }
694
    /// Collapses this request's settings into the internal `AgentConfig`
    /// consumed by the raw provider backends.
    ///
    /// NOTE(review): `tool_choice` is not forwarded here — confirm whether the
    /// backends receive it by some other path or whether it is silently dropped.
    fn to_agent_config(&self) -> crate::config::AgentConfig {
        crate::config::AgentConfig {
            base_url: self.effective_base_url().to_string(),
            model: self.model.clone(),
            system_prompt: self.system_message.clone(),
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            extra_body: self.extra_body.clone(),
            response_format: self.response_format.clone(),
            max_retries: self.max_retries,
            retry_delay_ms: self.retry_delay_ms,
        }
    }
711}
712
713fn degrade_json_schema_for_deepseek(
716 mut config: crate::config::AgentConfig,
717) -> crate::config::AgentConfig {
718 if matches!(
719 config.response_format,
720 Some(ResponseFormat::JsonSchema { .. })
721 ) {
722 tracing::warn!("DeepSeek does not support json_schema; degrading to json_object");
723 config.response_format = Some(ResponseFormat::JsonObject);
724 }
725 config
726}
727
#[cfg(test)]
mod truncate_tests {
    use super::*;

    // Messages are padded (repeat(200)) so each carries meaningful token
    // weight relative to the budgets used below.
    fn user(s: &str) -> Message {
        Message::User(vec![crate::UserContent::Text {
            text: s.repeat(200),
        }])
    }
    fn assistant_text(s: &str) -> Message {
        Message::Assistant {
            content: Some(s.repeat(200)),
            reasoning: None,
            tool_calls: vec![],
        }
    }
    fn assistant_tc(ids: &[&str]) -> Message {
        Message::Assistant {
            content: None,
            reasoning: None,
            tool_calls: ids
                .iter()
                .map(|id| ToolCall {
                    id: id.to_string(),
                    name: "bash".to_string(),
                    arguments: "{}".to_string(),
                })
                .collect(),
        }
    }
    fn tool_result(id: &str) -> Message {
        Message::ToolResult {
            call_id: id.to_string(),
            content: vec![Content::text("ok")],
        }
    }

    /// Asserts every surviving ToolResult has a matching tool-call id
    /// somewhere in the surviving history.
    fn no_orphans(history: &[Message]) {
        use std::collections::HashSet;
        let called: HashSet<&str> = history
            .iter()
            .filter_map(|m| {
                if let Message::Assistant { tool_calls, .. } = m {
                    Some(tool_calls.iter().map(|tc| tc.id.as_str()))
                } else {
                    None
                }
            })
            .flatten()
            .collect();

        for m in history {
            if let Message::ToolResult { call_id, .. } = m {
                assert!(
                    called.contains(call_id.as_str()),
                    "orphaned ToolResult with call_id={call_id}"
                );
            }
        }
    }

    #[test]
    fn test_no_truncation_needed() {
        let mut h = vec![user("a"), assistant_text("b")];
        truncate_to_token_budget(&mut h, 1_000_000);
        assert_eq!(h.len(), 2);
    }

    #[test]
    fn test_orphaned_tool_results_skipped_at_start() {
        let mut h = vec![
            user("x"),
            assistant_tc(&["id1"]),
            tool_result("id1"),
            user("y"),
            assistant_text("z"),
        ];
        // Budget fits the last two messages plus a little slack, so the cut
        // lands inside the tool exchange.
        let budget = h[3..].iter().map(|m| m.estimate_tokens()).sum::<usize>() + 10;
        truncate_to_token_budget(&mut h, budget);
        no_orphans(&h);
        // The leading ToolResult (and its Assistant) must be dropped entirely.
        // (The previous assertion here was a tautology, `!A || A`, which could
        // never fail.)
        assert!(!h.iter().any(|m| matches!(m, Message::ToolResult { .. })));
    }

    #[test]
    fn test_assistant_with_tool_calls_not_split_from_results() {
        let mut h = vec![
            user("old"),
            assistant_tc(&["a1", "a2"]),
            tool_result("a1"),
            tool_result("a2"),
            user("new"),
            assistant_text("reply"),
        ];
        let budget = h[4..].iter().map(|m| m.estimate_tokens()).sum::<usize>() + 10;
        truncate_to_token_budget(&mut h, budget);
        no_orphans(&h);
        assert!(!h.iter().any(|m| matches!(m, Message::ToolResult { .. })));
    }

    #[test]
    fn test_always_keeps_at_least_one_message() {
        let mut h = vec![user("only")];
        truncate_to_token_budget(&mut h, 1);
        assert_eq!(h.len(), 1);
    }
}
843
/// Desired output format for the model's response.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Free-form text (the default).
    #[default]
    Text,
    /// Any syntactically valid JSON object, no particular schema.
    #[serde(rename = "json_object")]
    JsonObject,
    /// JSON conforming to a named schema; `strict` requests exact adherence.
    JsonSchema {
        name: String,
        schema: serde_json::Value,
        strict: bool,
    },
}