use crate::mistral_v3::config::MistralV3Config;
use crate::mistral_v3::model::MistralV3ForCausalLM;
use std::collections::HashMap;
use thiserror::Error;
use trustformers_core::errors::Result as TFResult;
#[derive(Debug, Error)]
pub enum MistralV3Error {
#[error("configuration error: {0}")]
Config(String),
#[error("no tool calls found in model output")]
NoToolCalls,
#[error("JSON parse error: {0}")]
JsonParse(String),
#[error("unknown tool: {0}")]
UnknownTool(String),
#[error("missing required parameter: {0}")]
MissingRequiredParam(String),
#[error("tensor operation error: {0}")]
TensorOp(String),
}
/// Marker strings that bracket the tool-related sections of a prompt or a
/// model response.
#[derive(Clone, Debug)]
pub struct ToolUseTokens {
    /// Marker that opens a tool-call section (default `[TOOL_CALLS]`).
    pub tool_call_start: String,
    /// Marker that closes a tool-call section (default `[/TOOL_CALLS]`).
    pub tool_call_end: String,
    /// Marker that opens a tool-results section (default `[TOOL_RESULTS]`).
    pub tool_results_start: String,
    /// Marker that closes a tool-results section (default `[/TOOL_RESULTS]`).
    pub tool_results_end: String,
    /// Marker that opens the available-tools section (default `[AVAILABLE_TOOLS]`).
    pub available_tools_start: String,
    /// Marker that closes the available-tools section (default `[/AVAILABLE_TOOLS]`).
    pub available_tools_end: String,
}

impl Default for ToolUseTokens {
    /// Builds the marker set with the bracketed default strings listed on
    /// each field.
    fn default() -> Self {
        ToolUseTokens {
            tool_call_start: String::from("[TOOL_CALLS]"),
            tool_call_end: String::from("[/TOOL_CALLS]"),
            tool_results_start: String::from("[TOOL_RESULTS]"),
            tool_results_end: String::from("[/TOOL_RESULTS]"),
            available_tools_start: String::from("[AVAILABLE_TOOLS]"),
            available_tools_end: String::from("[/AVAILABLE_TOOLS]"),
        }
    }
}
/// One named parameter accepted by a tool.
#[derive(Clone, Debug)]
pub struct ToolParameter {
    /// Parameter name.
    pub name: String,
    /// Type label for the parameter, stored as free-form text.
    pub param_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Whether the parameter is marked as required.
    pub required: bool,
}

/// A tool described by its name, a description and its parameters.
#[derive(Clone, Debug)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameters of the tool, in declaration order.
    pub parameters: Vec<ToolParameter>,
}

impl ToolDefinition {
    /// Returns the names of all parameters flagged `required`, borrowed from
    /// `self` and kept in declaration order.
    pub fn required_params(&self) -> Vec<&str> {
        let mut names = Vec::new();
        for param in &self.parameters {
            if param.required {
                names.push(param.name.as_str());
            }
        }
        names
    }
}
/// A single tool invocation requested by the model.
///
/// Derives `PartialEq`/`Eq` so that parsed requests can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallRequest {
    /// Name of the tool to invoke.
    pub tool_name: String,
    /// Argument values keyed by parameter name, each stored as text.
    pub arguments: HashMap<String, String>,
    /// Identifier associated with the call.
    pub call_id: String,
}
/// Builds a prompt that advertises `tools` to the model, followed by
/// `user_message`.
///
/// Every tool is rendered as a JSON object of the form
/// `{"type": "function", "function": {"name", "description", "parameters"}}`,
/// where `parameters` is an object schema listing each parameter's `type` and
/// `description` plus the names of the required parameters. The JSON array of
/// all tools is wrapped in the default available-tools markers from
/// [`ToolUseTokens::default`], then a blank line and the user message follow.
///
/// If serialising the tool list fails, the literal `[]` is used in its place.
pub fn format_tool_call_prompt(tools: &[ToolDefinition], user_message: &str) -> String {
    let markers = ToolUseTokens::default();

    let mut schemas: Vec<serde_json::Value> = Vec::with_capacity(tools.len());
    for tool in tools {
        // Per-parameter schema entries, keyed by parameter name.
        let mut properties = serde_json::Map::new();
        for param in &tool.parameters {
            let schema = serde_json::json!({
                "type": param.param_type,
                "description": param.description,
            });
            properties.insert(param.name.clone(), schema);
        }

        // Names of the parameters flagged as required, in declaration order.
        let required: Vec<&str> = tool
            .parameters
            .iter()
            .filter(|param| param.required)
            .map(|param| param.name.as_str())
            .collect();

        schemas.push(serde_json::json!({
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                }
            }
        }));
    }

    let tools_json = match serde_json::to_string(&schemas) {
        Ok(text) => text,
        Err(_) => "[]".to_string(),
    };

    // Layout: "<start> <json><end>\n\n<user message>".
    let mut prompt = String::new();
    prompt.push_str(&markers.available_tools_start);
    prompt.push(' ');
    prompt.push_str(&tools_json);
    prompt.push_str(&markers.available_tools_end);
    prompt.push_str("\n\n");
    prompt.push_str(user_message);
    prompt
}
pub fn parse_tool_call_response(
response: &str,
tokens: &ToolUseTokens,
) -> Option<Vec<ToolCallRequest>> {
let start_pos = response.find(&tokens.tool_call_start)?;
let after_start = &response[start_pos + tokens.tool_call_start.len()..];
let content = if let Some(end_pos) = after_start.find(&tokens.tool_call_end) {
&after_start[..end_pos]
} else {
after_start
}
.trim();
let json_start = content.find('[')?;
let json_candidate = &content[json_start..];
let json_str = balanced_json_array(json_candidate)?;
let raw: Vec<serde_json::Value> = serde_json::from_str(json_str).ok()?;
if raw.is_empty() {
return None;
}
let mut calls = Vec::with_capacity(raw.len());
for item in raw {
let tool_name = item.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
let call_id = item.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
let arguments_val = item.get("arguments").cloned().unwrap_or(serde_json::json!({}));
let arguments: HashMap<String, String> =
if let serde_json::Value::Object(map) = arguments_val {
map.into_iter()
.map(|(k, v)| {
let val_str = match &v {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
(k, val_str)
})
.collect()
} else {
HashMap::new()
};
if !tool_name.is_empty() {
calls.push(ToolCallRequest {
tool_name,
arguments,
call_id,
});
}
}
if calls.is_empty() {
None
} else {
Some(calls)
}
}
/// A [`MistralV3ForCausalLM`] bundled with a set of tool definitions and the
/// marker strings used when parsing tool calls.
pub struct MistralV3WithTools {
    inner: MistralV3ForCausalLM,
    tools: Vec<ToolDefinition>,
    tokens: ToolUseTokens,
}

impl MistralV3WithTools {
    /// Wraps `inner` together with `tools`, using the default
    /// [`ToolUseTokens`] markers.
    pub fn new(inner: MistralV3ForCausalLM, tools: Vec<ToolDefinition>) -> Self {
        let tokens = ToolUseTokens::default();
        Self {
            inner,
            tools,
            tokens,
        }
    }

    /// Returns the configuration reported by the wrapped model.
    pub fn config(&self) -> &MistralV3Config {
        let model = &self.inner;
        model.config()
    }

    /// Formats `user_message` together with this wrapper's tools via
    /// [`format_tool_call_prompt`].
    ///
    /// Note that `format_tool_call_prompt` builds its own default markers, so
    /// the `tokens` stored in this wrapper are not consulted here.
    pub fn format_prompt(&self, user_message: &str) -> String {
        format_tool_call_prompt(self.tools.as_slice(), user_message)
    }

    /// Parses tool calls out of `response` via [`parse_tool_call_response`],
    /// using the markers stored in this wrapper.
    pub fn parse_response(&self, response: &str) -> Option<Vec<ToolCallRequest>> {
        let markers = &self.tokens;
        parse_tool_call_response(response, markers)
    }

    /// Forwards `input_ids` to the wrapped model's `forward` and returns its
    /// result unchanged.
    pub fn forward(&self, input_ids: Vec<u32>) -> TFResult<trustformers_core::tensor::Tensor> {
        let model = &self.inner;
        model.forward(input_ids)
    }
}
/// Returns the prefix of `s` that ends at the `]` which brings the square
/// bracket nesting depth back to zero, or `None` if that never happens.
///
/// Brackets inside double-quoted strings are ignored, and a backslash inside
/// a string causes the following character to be skipped so that `\"` does
/// not end the string. Only `[` and `]` are counted; other characters outside
/// strings are passed over.
fn balanced_json_array(s: &str) -> Option<&str> {
    /// Scanner position relative to JSON string literals.
    enum Scan {
        /// Outside any string literal.
        Structure,
        /// Inside a string literal.
        InString,
        /// Inside a string literal, directly after a backslash.
        Escaped,
    }

    let mut state = Scan::Structure;
    let mut open_brackets = 0_i32;

    // Every character of interest is ASCII, and ASCII bytes never occur
    // inside a multi-byte UTF-8 sequence, so scanning bytes is sufficient.
    for (pos, byte) in s.bytes().enumerate() {
        state = match (state, byte) {
            (Scan::Escaped, _) => Scan::InString,
            (Scan::InString, b'\\') => Scan::Escaped,
            (Scan::InString, b'"') => Scan::Structure,
            (Scan::InString, _) => Scan::InString,
            (Scan::Structure, b'"') => Scan::InString,
            (Scan::Structure, b'[') => {
                open_brackets += 1;
                Scan::Structure
            }
            (Scan::Structure, b']') => {
                open_brackets -= 1;
                if open_brackets == 0 {
                    // `pos` is the index of an ASCII `]`, so the inclusive
                    // range ends on a character boundary.
                    return Some(&s[..=pos]);
                }
                Scan::Structure
            }
            (Scan::Structure, _) => Scan::Structure,
        };
    }
    None
}