use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use crate::error::ApiError;
use crate::msg::LlmEvent;
use crate::raw::shared::ToolDefinition;
use crate::types::CompleteResponse;
/// An image attached to a message, paired with its MIME type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ImageContent {
    /// The image payload: inline base64 bytes or a URL reference.
    pub data: ImageData,
    /// MIME type of the image, e.g. `"image/png"`.
    pub mime_type: String,
}
/// How an image's bytes are conveyed to the provider.
/// Serialized with snake_case variant names (`base64` / `url`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ImageData {
    /// Base64-encoded image bytes, inlined in the request.
    Base64(String),
    /// A URL the provider fetches the image from.
    Url(String),
}
/// One part of a message: plain text or an image.
/// Serialized internally tagged via a `"type"` field (`text` / `image`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// A plain-text content part.
    Text { text: String },
    /// An image content part.
    Image(ImageContent),
}
impl Content {
pub fn text(s: impl Into<String>) -> Self {
Content::Text { text: s.into() }
}
}
impl From<&str> for Content {
    /// Converts a string slice into an owned text content part.
    fn from(s: &str) -> Self {
        Content::text(s)
    }
}
impl From<String> for Content {
    /// Wraps an owned string as a text content part (no reallocation:
    /// `String` → `String` via `Into` is the identity conversion).
    fn from(s: String) -> Self {
        Content::text(s)
    }
}
/// Alias: user-authored message parts share the general [`Content`] shape.
pub type UserContent = Content;
/// A single turn in a conversation history.
#[derive(Debug, Clone)]
pub enum Message {
    /// A user turn: one or more content parts (text and/or images).
    User(Vec<UserContent>),
    /// An assistant turn.
    Assistant {
        /// Visible assistant text, if any.
        content: Option<String>,
        /// Provider-reported reasoning/"thinking" text, if any.
        reasoning: Option<String>,
        /// Tool invocations requested by the assistant; may be empty.
        tool_calls: Vec<ToolCall>,
        /// Opaque provider-specific payload — presumably round-tripped back
        /// to the provider on later requests; TODO confirm against the raw
        /// provider modules.
        provider_data: Option<serde_json::Value>,
    },
    /// The result of executing one tool call.
    ToolResult {
        /// Matches the `id` of the originating [`ToolCall`].
        call_id: String,
        /// Tool output as content parts.
        content: Vec<Content>,
    },
}
impl Message {
    /// Roughly estimates this message's token footprint using the
    /// `cl100k_base` tokenizer (initialized once via a lazy static).
    ///
    /// Heuristics: every message carries a flat 4-token structural overhead;
    /// each image part is charged a flat 1000 tokens; text parts, tool-call
    /// names, and tool-call argument strings are actually tokenized.
    pub fn estimate_tokens(&self) -> usize {
        use std::sync::OnceLock;
        static BPE: OnceLock<tiktoken_rs::CoreBPE> = OnceLock::new();
        let bpe = BPE.get_or_init(|| tiktoken_rs::cl100k_base().unwrap());
        // Token length of a text fragment.
        let count = |s: &str| bpe.encode_with_special_tokens(s).len();
        // Flat charge per image part, regardless of actual size.
        const IMAGE_COST: usize = 1000;
        // Per-message structural overhead.
        const MESSAGE_OVERHEAD: usize = 4;
        let body: usize = match self {
            Message::User(parts) => parts
                .iter()
                .map(|part| match part {
                    UserContent::Text { text } => count(text),
                    UserContent::Image(_) => IMAGE_COST,
                })
                .sum(),
            Message::Assistant {
                content,
                reasoning,
                tool_calls,
                ..
            } => {
                let text = content.as_deref().map_or(0, count);
                let thought = reasoning.as_deref().map_or(0, count);
                let calls: usize = tool_calls
                    .iter()
                    .map(|tc| count(&tc.name) + count(&tc.arguments))
                    .sum();
                text + thought + calls
            }
            Message::ToolResult { content, .. } => content
                .iter()
                .map(|part| match part {
                    Content::Text { text } => count(text),
                    Content::Image(_) => IMAGE_COST,
                })
                .sum(),
        };
        MESSAGE_OVERHEAD + body
    }
}
/// Drops the oldest messages from `history` so the remaining suffix fits
/// within `budget` estimated tokens, while preserving two invariants:
///
/// 1. At least one message is always kept, even if it alone exceeds the
///    budget.
/// 2. No `ToolResult` is left orphaned from the `Assistant` message that
///    issued its tool call.
pub fn truncate_to_token_budget(history: &mut Vec<Message>, budget: usize) {
    // Walk newest-to-oldest accumulating estimated tokens; `keep_from` becomes
    // the first index of the suffix we keep.
    let mut acc: usize = 0;
    let mut keep_from = history.len();
    for (i, msg) in history.iter().enumerate().rev() {
        acc += msg.estimate_tokens();
        if acc > budget {
            // Message `i` broke the budget: keep from `i + 1`, but never past
            // the final message (invariant 1).
            keep_from = (i + 1).min(history.len() - 1);
            break;
        }
    }
    if keep_from == history.len() {
        // Everything fits; nothing to truncate.
        return;
    }
    // Skip ToolResults sitting at the cut point: their Assistant was dropped,
    // so keeping them would orphan them (invariant 2).
    while keep_from < history.len() {
        match &history[keep_from] {
            Message::ToolResult { .. } => keep_from += 1,
            _ => break,
        }
    }
    // Never split an Assistant-with-tool-calls from its results: if the cut
    // lands on such an Assistant, drop it together with its matching
    // ToolResults rather than keeping a dangling call.
    if keep_from < history.len()
        && let Message::Assistant { tool_calls, .. } = &history[keep_from]
        && !tool_calls.is_empty()
    {
        let ids: std::collections::HashSet<&str> =
            tool_calls.iter().map(|tc| tc.id.as_str()).collect();
        keep_from += 1;
        while keep_from < history.len() {
            match &history[keep_from] {
                Message::ToolResult { call_id, .. } if ids.contains(call_id.as_str()) => {
                    keep_from += 1;
                }
                _ => break,
            }
        }
    }
    if keep_from >= history.len() {
        // The skipping above consumed everything; clamp back so the final
        // message survives (invariant 1)...
        keep_from = history.len().saturating_sub(1);
        // ...but if that final message is a ToolResult, walk back to include
        // the Assistant that issued it (invariant 2). Without this, a history
        // like `[Assistant(tool_calls), ToolResult]` under a tiny budget was
        // truncated to just the orphaned `ToolResult`.
        while keep_from > 0 && matches!(history[keep_from], Message::ToolResult { .. }) {
            keep_from -= 1;
        }
    }
    if keep_from > 0 {
        history.drain(0..keep_from);
    }
}
/// A tool invocation requested by the assistant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; `Message::ToolResult::call_id` must match it.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw argument payload as produced by the model (typically a JSON
    /// string; not parsed here).
    pub arguments: String,
}
/// The supported LLM backends. Each variant serializes to a lowercase
/// string id (see the per-variant `rename`s).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    #[serde(rename = "deepseek")]
    DeepSeek,
    #[serde(rename = "openai")]
    OpenAI,
    #[serde(rename = "anthropic")]
    Anthropic,
    #[serde(rename = "gemini")]
    Gemini,
    #[serde(rename = "kimi")]
    Kimi,
    #[serde(rename = "glm")]
    Glm,
    #[serde(rename = "minimax")]
    Minimax,
    #[serde(rename = "grok")]
    Grok,
    #[serde(rename = "openrouter")]
    OpenRouter,
    /// Local Claude Code integration; only available behind the
    /// `claude-code` cargo feature.
    #[cfg(feature = "claude-code")]
    #[serde(rename = "claude-code")]
    ClaudeCode,
}
impl Provider {
    /// The API base URL used when `Request::base_url` is empty.
    ///
    /// Note: `Minimax` points at an Anthropic-compatible endpoint
    /// (`/anthropic`), matching the Anthropic code path used for it in
    /// `Request::stream`/`complete`. `ClaudeCode` has no HTTP base URL.
    pub fn default_base_url(&self) -> &'static str {
        match self {
            Provider::DeepSeek => "https://api.deepseek.com",
            Provider::OpenAI => "https://api.openai.com/v1",
            Provider::Anthropic => "https://api.anthropic.com",
            Provider::Gemini => "https://generativelanguage.googleapis.com/v1beta",
            Provider::Kimi => "https://api.moonshot.cn/v1",
            Provider::Glm => "https://open.bigmodel.cn/api/paas/v4",
            Provider::Minimax => "https://api.minimaxi.com/anthropic",
            Provider::Grok => "https://api.x.ai/v1",
            Provider::OpenRouter => "https://openrouter.ai/api/v1",
            #[cfg(feature = "claude-code")]
            Provider::ClaudeCode => "",
        }
    }
    /// The model id used when the caller does not override `Request::model`.
    pub fn default_model(&self) -> &'static str {
        match self {
            Provider::DeepSeek => "deepseek-chat",
            Provider::OpenAI => "gpt-4o",
            Provider::Anthropic => "claude-sonnet-4-20250514",
            Provider::Gemini => "gemini-2.0-flash",
            Provider::Kimi => "kimi-k2.5",
            Provider::Glm => "glm-5",
            Provider::Minimax => "MiniMax-M2.7",
            Provider::Grok => "grok-4",
            Provider::OpenRouter => "openrouter/auto",
            #[cfg(feature = "claude-code")]
            Provider::ClaudeCode => "sonnet",
        }
    }
}
/// Requested reasoning/thinking intensity, from disabled to maximal.
/// Serializes to lowercase variant names; `XHigh` is explicitly `"xhigh"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    None,
    Minimal,
    Low,
    Medium,
    High,
    #[serde(rename = "xhigh")]
    XHigh,
    Max,
}
/// How the model may use the tools supplied with a request.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// The model decides whether to call tools (the default).
    #[default]
    Auto,
    /// Tool calls are disallowed.
    None,
    /// The model must call some tool.
    Required,
    /// The model must call the named tool.
    Tool(String),
}
/// A fully-specified completion request, built via [`Request::new`] and the
/// chained builder methods, then executed with `stream` or `complete`.
#[derive(Debug, Clone)]
pub struct Request {
    /// Backend to dispatch to.
    pub provider: Provider,
    /// API key; may be empty for providers that need none (e.g. claude-code).
    pub api_key: String,
    /// Base URL override; empty means use the provider default
    /// (see `effective_base_url`).
    pub base_url: String,
    /// Model id.
    pub model: String,
    /// Optional system prompt.
    pub system_message: Option<String>,
    /// Conversation history.
    pub messages: Vec<Message>,
    /// Optional reminder text forwarded to the provider layer.
    pub reminder: Option<String>,
    /// Tools the model may call.
    pub tools: Vec<ToolDefinition>,
    /// Tool-use policy; `None` leaves the provider default.
    pub tool_choice: Option<ToolChoice>,
    /// Sampling temperature; `None` leaves the provider default.
    pub temperature: Option<f32>,
    /// Maximum tokens to generate; `None` leaves the provider default.
    pub max_tokens: Option<u32>,
    /// Reasoning intensity; `None` leaves the provider default.
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Output format constraint (text / JSON / JSON schema).
    pub response_format: Option<ResponseFormat>,
    /// Extra top-level fields merged into the provider request body.
    pub extra_body: serde_json::Map<String, serde_json::Value>,
    /// Retry policy: attempt count and initial delay.
    pub max_retries: u32,
    pub retry_delay_ms: u64,
}
impl Request {
    /// Creates a request for `provider` with its default base URL and model,
    /// no messages/tools, and a retry policy of 3 attempts starting at 1s.
    pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
        Self {
            base_url: provider.default_base_url().to_string(),
            model: provider.default_model().to_string(),
            api_key: api_key.into(),
            provider,
            system_message: None,
            messages: Vec::new(),
            reminder: None,
            tools: Vec::new(),
            tool_choice: None,
            temperature: None,
            reasoning_effort: None,
            max_tokens: None,
            response_format: None,
            extra_body: serde_json::Map::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
        }
    }
    /// Shorthand for `Request::new(Provider::DeepSeek, api_key)`.
    pub fn deepseek(api_key: impl Into<String>) -> Self {
        Self::new(Provider::DeepSeek, api_key)
    }
    /// Shorthand for `Request::new(Provider::OpenAI, api_key)`.
    pub fn openai(api_key: impl Into<String>) -> Self {
        Self::new(Provider::OpenAI, api_key)
    }
    /// Shorthand for `Request::new(Provider::Anthropic, api_key)`.
    pub fn anthropic(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Anthropic, api_key)
    }
    /// Shorthand for `Request::new(Provider::Gemini, api_key)`.
    pub fn gemini(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Gemini, api_key)
    }
    /// Shorthand for `Request::new(Provider::Kimi, api_key)`.
    pub fn kimi(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Kimi, api_key)
    }
    /// Shorthand for `Request::new(Provider::Glm, api_key)`.
    pub fn glm(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Glm, api_key)
    }
    /// Shorthand for `Request::new(Provider::Minimax, api_key)`.
    pub fn minimax(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Minimax, api_key)
    }
    /// Shorthand for `Request::new(Provider::Grok, api_key)`.
    pub fn grok(api_key: impl Into<String>) -> Self {
        Self::new(Provider::Grok, api_key)
    }
    /// Shorthand for `Request::new(Provider::OpenRouter, api_key)`.
    pub fn openrouter(api_key: impl Into<String>) -> Self {
        Self::new(Provider::OpenRouter, api_key)
    }
    /// Claude Code needs no API key, so none is taken.
    #[cfg(feature = "claude-code")]
    pub fn claude_code() -> Self {
        Self::new(Provider::ClaudeCode, String::new())
    }
    /// Overrides the base URL (builder style).
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }
    /// Overrides the model id (builder style).
    pub fn model(mut self, m: impl Into<String>) -> Self {
        self.model = m.into();
        self
    }
    /// Sets the system prompt (builder style).
    pub fn system_prompt(mut self, p: impl Into<String>) -> Self {
        self.system_message = Some(p.into());
        self
    }
    /// Sets the reminder text (builder style).
    pub fn reminder(mut self, p: impl Into<String>) -> Self {
        self.reminder = Some(p.into());
        self
    }
    /// Appends one message to the history (builder style).
    pub fn message(mut self, m: Message) -> Self {
        self.messages.push(m);
        self
    }
    /// Appends a single-text-part user message (builder style).
    pub fn user(self, text: impl Into<String>) -> Self {
        self.message(Message::User(vec![Content::text(text)]))
    }
    /// Replaces the entire message history (builder style).
    pub fn messages(mut self, msgs: Vec<Message>) -> Self {
        self.messages = msgs;
        self
    }
    /// Replaces the tool list (builder style).
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = tools;
        self
    }
    /// Sets the sampling temperature (builder style).
    pub fn temperature(mut self, t: f32) -> Self {
        self.temperature = Some(t);
        self
    }
    /// Sets the generation token cap (builder style).
    pub fn max_tokens(mut self, n: u32) -> Self {
        self.max_tokens = Some(n);
        self
    }
    /// Sets the reasoning effort (builder style).
    pub fn reasoning_effort(mut self, e: ReasoningEffort) -> Self {
        self.reasoning_effort = Some(e);
        self
    }
    /// Requests plain-text output (builder style).
    pub fn text(mut self) -> Self {
        self.response_format = Some(ResponseFormat::Text);
        self
    }
    /// Requests schema-constrained JSON output (builder style).
    pub fn json_schema(
        mut self,
        name: impl Into<String>,
        schema: serde_json::Value,
        strict: bool,
    ) -> Self {
        self.response_format = Some(ResponseFormat::JsonSchema {
            name: name.into(),
            schema,
            strict,
        });
        self
    }
    /// Requests free-form JSON-object output (builder style).
    pub fn json(mut self) -> Self {
        self.response_format = Some(ResponseFormat::JsonObject);
        self
    }
    /// Sets the retry count and initial retry delay (builder style).
    pub fn retries(mut self, max: u32, initial_delay_ms: u64) -> Self {
        self.max_retries = max;
        self.retry_delay_ms = initial_delay_ms;
        self
    }
    /// Replaces the extra request-body fields (builder style).
    pub fn extra_body(mut self, extra: serde_json::Map<String, serde_json::Value>) -> Self {
        self.extra_body = extra;
        self
    }
    /// The base URL actually used: the explicit override if non-empty,
    /// otherwise the provider default.
    pub fn effective_base_url(&self) -> &str {
        if self.base_url.is_empty() {
            self.provider.default_base_url()
        } else {
            &self.base_url
        }
    }
    /// Executes the request as a streaming call, dispatching to the raw
    /// implementation for `self.provider` and yielding `LlmEvent`s.
    ///
    /// Note: `Minimax` deliberately reuses the Anthropic implementation
    /// (its default base URL targets an Anthropic-compatible endpoint).
    pub async fn stream(
        &self,
        http: &reqwest::Client,
    ) -> Result<BoxStream<'static, LlmEvent>, ApiError> {
        let config = self.to_agent_config();
        let messages = &self.messages;
        let tools = &self.tools;
        match self.provider {
            Provider::DeepSeek => {
                crate::raw::deepseek::stream_deepseek(&self.api_key, http, &config, messages, tools)
                    .await
            }
            Provider::OpenAI => {
                crate::raw::openai::stream_openai(&self.api_key, http, &config, messages, tools)
                    .await
            }
            Provider::Anthropic => {
                crate::raw::anthropic::stream_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Gemini => {
                crate::raw::gemini::stream_gemini(&self.api_key, http, &config, messages, tools)
                    .await
            }
            // Minimax speaks the Anthropic wire protocol.
            Provider::Minimax => {
                crate::raw::anthropic::stream_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Kimi => {
                crate::raw::kimi::stream_kimi(&self.api_key, http, &config, messages, tools).await
            }
            Provider::Glm => {
                crate::raw::glm::stream_glm(&self.api_key, http, &config, messages, tools).await
            }
            Provider::Grok => {
                crate::raw::grok::stream_grok(&self.api_key, http, &config, messages, tools).await
            }
            Provider::OpenRouter => {
                crate::raw::openrouter::stream_openrouter(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            #[cfg(feature = "claude-code")]
            Provider::ClaudeCode => {
                crate::raw::claude_code::stream_claude_code(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
        }
    }
    /// Executes the request as a single (non-streaming) completion,
    /// dispatching to the raw implementation for `self.provider`.
    ///
    /// Note: `Minimax` deliberately reuses the Anthropic implementation,
    /// mirroring `stream`.
    pub async fn complete(&self, http: &reqwest::Client) -> Result<CompleteResponse, ApiError> {
        let config = self.to_agent_config();
        let messages = &self.messages;
        let tools = &self.tools;
        match self.provider {
            Provider::DeepSeek => {
                crate::raw::deepseek::complete_deepseek(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::OpenAI => {
                crate::raw::openai::complete_openai(&self.api_key, http, &config, messages, tools)
                    .await
            }
            Provider::Anthropic => {
                crate::raw::anthropic::complete_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Gemini => {
                crate::raw::gemini::complete_gemini(&self.api_key, http, &config, messages, tools)
                    .await
            }
            // Minimax speaks the Anthropic wire protocol.
            Provider::Minimax => {
                crate::raw::anthropic::complete_anthropic(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            Provider::Kimi => {
                crate::raw::kimi::complete_kimi(&self.api_key, http, &config, messages, tools).await
            }
            Provider::Glm => {
                crate::raw::glm::complete_glm(&self.api_key, http, &config, messages, tools).await
            }
            Provider::Grok => {
                crate::raw::grok::complete_grok(&self.api_key, http, &config, messages, tools).await
            }
            Provider::OpenRouter => {
                crate::raw::openrouter::complete_openrouter(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
            #[cfg(feature = "claude-code")]
            Provider::ClaudeCode => {
                crate::raw::claude_code::complete_claude_code(
                    &self.api_key,
                    http,
                    &config,
                    messages,
                    tools,
                )
                .await
            }
        }
    }
    /// Projects this request into the `AgentConfig` shape the raw provider
    /// modules consume. Clones owned fields; the API key, messages, tools,
    /// and tool_choice are passed separately.
    fn to_agent_config(&self) -> crate::config::AgentConfig {
        crate::config::AgentConfig {
            base_url: self.effective_base_url().to_string(),
            model: self.model.clone(),
            system_prompt: self.system_message.clone(),
            reminder: self.reminder.clone(),
            max_tokens: self.max_tokens,
            temperature: self.temperature,
            reasoning_effort: self.reasoning_effort,
            extra_body: self.extra_body.clone(),
            response_format: self.response_format.clone(),
            max_retries: self.max_retries,
            retry_delay_ms: self.retry_delay_ms,
        }
    }
}
#[cfg(test)]
mod truncate_tests {
    use super::*;
    // Builds a user message with a bulky text body (forces real token cost).
    fn user(s: &str) -> Message {
        Message::User(vec![crate::UserContent::Text {
            text: s.repeat(200),
        }])
    }
    // Builds a text-only assistant message with a bulky body.
    fn assistant_text(s: &str) -> Message {
        Message::Assistant {
            content: Some(s.repeat(200)),
            reasoning: None,
            tool_calls: vec![],
            provider_data: None,
        }
    }
    // Builds an assistant message issuing one tool call per id.
    fn assistant_tc(ids: &[&str]) -> Message {
        Message::Assistant {
            content: None,
            reasoning: None,
            tool_calls: ids
                .iter()
                .map(|id| ToolCall {
                    id: id.to_string(),
                    name: "bash".to_string(),
                    arguments: "{}".to_string(),
                })
                .collect(),
            provider_data: None,
        }
    }
    // Builds a ToolResult answering the given call id.
    fn tool_result(id: &str) -> Message {
        Message::ToolResult {
            call_id: id.to_string(),
            content: vec![Content::text("ok")],
        }
    }
    // Asserts the invariant that every ToolResult in `history` has a matching
    // tool call on some surviving Assistant message.
    fn no_orphans(history: &[Message]) {
        use std::collections::HashSet;
        let called: HashSet<&str> = history
            .iter()
            .filter_map(|m| {
                if let Message::Assistant { tool_calls, .. } = m {
                    Some(tool_calls.iter().map(|tc| tc.id.as_str()))
                } else {
                    None
                }
            })
            .flatten()
            .collect();
        for m in history {
            if let Message::ToolResult { call_id, .. } = m {
                assert!(
                    called.contains(call_id.as_str()),
                    "orphaned ToolResult with call_id={call_id}"
                );
            }
        }
    }
    // A generous budget must leave the history untouched.
    #[test]
    fn test_no_truncation_needed() {
        let mut h = vec![user("a"), assistant_text("b")];
        truncate_to_token_budget(&mut h, 1_000_000);
        assert_eq!(h.len(), 2);
    }
    // When the cut lands on a ToolResult, it must be skipped so the kept
    // history never starts with an orphan.
    #[test]
    fn test_orphaned_tool_results_skipped_at_start() {
        let mut h = vec![
            user("x"),
            assistant_tc(&["id1"]),
            tool_result("id1"),
            user("y"),
            assistant_text("z"),
        ];
        // Budget fits exactly the last two messages plus slack.
        let budget = h[3..].iter().map(|m| m.estimate_tokens()).sum::<usize>() + 10;
        truncate_to_token_budget(&mut h, budget);
        no_orphans(&h);
        assert!(!matches!(h.first(), Some(Message::ToolResult { .. })));
        no_orphans(&h);
    }
    // An Assistant carrying tool calls is dropped together with its results
    // rather than being separated from them.
    #[test]
    fn test_assistant_with_tool_calls_not_split_from_results() {
        let mut h = vec![
            user("old"),
            assistant_tc(&["a1", "a2"]),
            tool_result("a1"),
            tool_result("a2"),
            user("new"),
            assistant_text("reply"),
        ];
        // Budget fits exactly the last two messages plus slack.
        let budget = h[4..].iter().map(|m| m.estimate_tokens()).sum::<usize>() + 10;
        truncate_to_token_budget(&mut h, budget);
        no_orphans(&h);
        assert!(!h.iter().any(|m| matches!(m, Message::ToolResult { .. })));
    }
    // Even an impossible budget must leave at least one message.
    #[test]
    fn test_always_keeps_at_least_one_message() {
        let mut h = vec![user("only")];
        truncate_to_token_budget(&mut h, 1);
        assert_eq!(h.len(), 1);
    }
}
/// Output-format constraint forwarded to the provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Free-form text (the default).
    #[default]
    Text,
    /// Any valid JSON object.
    #[serde(rename = "json_object")]
    JsonObject,
    /// JSON conforming to the supplied schema.
    JsonSchema {
        /// Schema name reported to the provider.
        name: String,
        /// The JSON Schema document itself.
        schema: serde_json::Value,
        /// Whether the provider should enforce the schema strictly.
        strict: bool,
    },
}