use serde::{Deserialize, Serialize};
/// A single function/tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier, read from the top-level `id` field of the API payload.
    pub id: String,
    /// Name of the function to invoke (`function.name` in the API payload).
    pub name: String,
    /// Decoded call arguments. When this field is missing during serde
    /// deserialization it defaults to `Value::Null`.
    #[serde(default)]
    pub arguments: serde_json::Value,
}

impl ToolCall {
    /// Converts this call into the wire shape
    /// `{"id", "type": "function", "function": {"name", "arguments"}}`.
    ///
    /// `arguments` is emitted as a JSON-encoded *string*. A `Null` value (the serde
    /// default for a missing field) is emitted as `"{}"`, matching the empty-object
    /// default that [`ToolCall::from_api_dict`] uses, instead of the string `"null"`.
    pub fn to_api_dict(&self) -> serde_json::Value {
        let arguments = if self.arguments.is_null() {
            "{}".to_string()
        } else {
            // `Value`'s `Display` impl produces compact JSON and cannot fail, unlike
            // `serde_json::to_string(..).unwrap_or_default()`, which would yield "".
            self.arguments.to_string()
        };
        serde_json::json!({
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": arguments,
            },
        })
    }

    /// Builds a `ToolCall` from the wire shape produced by [`ToolCall::to_api_dict`].
    ///
    /// Missing `id` / `function.name` become empty strings. `function.arguments` is
    /// accepted either as a JSON-encoded string (decoded here) or as an already-decoded
    /// JSON object (used as-is). Anything else is replaced by an empty object. This
    /// includes a missing field, a non-JSON string, or another JSON type.
    pub fn from_api_dict(d: &serde_json::Value) -> Self {
        let func = &d["function"];
        let empty_object = || serde_json::Value::Object(serde_json::Map::new());
        let arguments = match &func["arguments"] {
            serde_json::Value::String(s) => {
                serde_json::from_str(s).unwrap_or_else(|_| empty_object())
            }
            // Some payloads carry the arguments pre-parsed; previously these were dropped.
            obj @ serde_json::Value::Object(_) => obj.clone(),
            _ => empty_object(),
        };
        Self {
            id: d["id"].as_str().unwrap_or("").to_string(),
            name: func["name"].as_str().unwrap_or("").to_string(),
            arguments,
        }
    }
}
/// One chat message, convertible to and from a JSON object via
/// `Message::to_api_dict` / `Message::from_api_dict`.
#[derive(Debug, Clone)]
pub struct Message {
    /// Speaker role. The `Message::new_*` constructors use "system", "user",
    /// "assistant" and "tool"; `from_api_dict` falls back to "user" when absent.
    pub role: String,
    /// Text content. `None` is written as JSON `null` by `to_api_dict`.
    pub content: Option<String>,
    /// Tool calls attached to the message (set by `new_assistant`); `None` when absent.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Set by `new_tool`; presumably links this result to a `ToolCall::id` — confirm.
    pub tool_call_id: Option<String>,
    /// Set by `new_tool` to the tool's name; `None` for other constructors.
    pub name: Option<String>,
}
/// Token counters for a chat session.
///
/// `prompt_tokens` / `completion_tokens` / `total_tokens` accumulate across every
/// call to [`Usage::record_turn`]. The `session_*` fields mirror those running
/// totals, and the `turn_*` fields hold only the most recently recorded turn.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    pub session_prompt: u64,
    pub session_completion: u64,
    pub session_total: u64,
    pub turn_prompt: u64,
    pub turn_completion: u64,
    pub turn_total: u64,
}

impl Usage {
    /// Returns the `(input, output)` cost per token for `model` (presumably USD — confirm).
    ///
    /// Matching is a case-insensitive substring test evaluated in order, so more
    /// specific patterns (`gpt-4o`) must stay ahead of broader ones (`gpt-4`).
    /// Unrecognised models use the final default rate.
    fn rates_per_token(model: &str) -> (f64, f64) {
        // Lowercase once so "GPT-4o" or "DeepSeek-Chat" are priced like their
        // lowercase spellings instead of falling through to the default.
        let model = model.to_ascii_lowercase();
        let has = |needle: &str| model.contains(needle);
        if has("deepseek") {
            if has("flash") || has("free") {
                (0.0, 0.0)
            } else {
                (0.14e-6, 0.28e-6)
            }
        } else if has("llama") || has("nemotron") {
            (0.10e-6, 0.20e-6)
        } else if has("gpt-4o") {
            (2.50e-6, 10.0e-6)
        } else if has("gpt-4") || has("claude") {
            (10.0e-6, 30.0e-6)
        } else if has("mixtral") || has("gemma") {
            (0.60e-6, 0.60e-6)
        } else if has("gemini") {
            (0.50e-6, 1.50e-6)
        } else if has("mistral") {
            (2.0e-6, 6.0e-6)
        } else {
            (0.50e-6, 1.50e-6)
        }
    }

    /// Estimated cost of `prompt_tokens` + `completion_tokens` under `model`'s rates.
    pub fn estimated_cost(&self, model: &str) -> f64 {
        let (input_rate, output_rate) = Self::rates_per_token(model);
        self.prompt_tokens as f64 * input_rate + self.completion_tokens as f64 * output_rate
    }

    /// Records one turn: overwrites the `turn_*` fields, adds the counts to the
    /// running totals, and syncs the `session_*` fields to those totals.
    pub fn record_turn(&mut self, prompt: u64, completion: u64) {
        self.turn_prompt = prompt;
        self.turn_completion = completion;
        self.turn_total = prompt + completion;
        self.prompt_tokens += prompt;
        self.completion_tokens += completion;
        self.total_tokens += prompt + completion;
        self.session_prompt = self.prompt_tokens;
        self.session_completion = self.completion_tokens;
        self.session_total = self.total_tokens;
    }

    /// Estimated cost of the most recently recorded turn only.
    #[allow(dead_code)] // kept for callers that report per-turn cost
    pub fn turn_cost(&self, model: &str) -> f64 {
        let usage = Usage {
            prompt_tokens: self.turn_prompt,
            completion_tokens: self.turn_completion,
            total_tokens: self.turn_total,
            ..Default::default()
        };
        usage.estimated_cost(model)
    }

    /// Estimated cost of all turns recorded so far (same as [`Usage::estimated_cost`]).
    pub fn session_cost(&self, model: &str) -> f64 {
        self.estimated_cost(model)
    }
}
impl Message {
pub fn new_system(content: impl Into<String>) -> Self {
Self {
role: "system".into(),
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_user(content: impl Into<String>) -> Self {
Self {
role: "user".into(),
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
Self {
role: "assistant".into(),
content,
tool_calls,
tool_call_id: None,
name: None,
}
}
pub fn new_tool(content: impl Into<String>, tool_call_id: impl Into<String>, name: impl Into<String>) -> Self {
Self {
role: "tool".into(),
content: Some(content.into()),
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
name: Some(name.into()),
}
}
pub fn to_api_dict(&self) -> serde_json::Value {
let mut d = serde_json::json!({
"role": self.role,
});
if let Some(ref content) = self.content {
d["content"] = serde_json::Value::String(content.clone());
} else {
d["content"] = serde_json::Value::Null;
}
if let Some(ref calls) = self.tool_calls {
d["tool_calls"] = calls.iter().map(|tc| tc.to_api_dict()).collect();
}
if let Some(ref id) = self.tool_call_id {
d["tool_call_id"] = serde_json::Value::String(id.clone());
}
if let Some(ref name) = self.name {
d["name"] = serde_json::Value::String(name.clone());
}
d
}
pub fn from_api_dict(d: &serde_json::Value) -> Self {
let tool_calls = d["tool_calls"]
.as_array()
.filter(|arr| !arr.is_empty())
.map(|arr| arr.iter().map(ToolCall::from_api_dict).collect());
Self {
role: d["role"].as_str().unwrap_or("user").to_string(),
content: d.get("content").and_then(|c| {
if c.is_null() { None } else { c.as_str().map(String::from) }
}),
tool_calls,
tool_call_id: d.get("tool_call_id").and_then(|v| v.as_str().map(String::from)),
name: d.get("name").and_then(|v| v.as_str().map(String::from)),
}
}
#[allow(dead_code)]
pub fn pretty_print(&self) -> String {
let mut out = format!("[{}]", self.role.to_uppercase());
if let Some(ref content) = self.content {
if content.len() > 2000 {
out.push_str(&content[..2000]);
out.push_str("\n... [truncated]");
} else {
out.push_str(content);
}
}
if let Some(ref calls) = self.tool_calls {
for tc in calls {
out.push_str(&format!("\n -> tool: {}({:?})", tc.name, tc.arguments));
}
}
out
}
}