use std::collections::HashMap;
use std::pin::Pin;
use futures::stream::{self, Stream, StreamExt as _};
use serde::Serialize;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use swink_agent::{
AgentContext, AgentMessage, AssistantMessageEvent, ModelSpec, StreamFn, StreamOptions,
};
use crate::convert;
use crate::oai_transport::{OaiAdapterShell, oai_send_and_parse};
use crate::openai_compat::{OaiConverter, OaiMessage, build_oai_tools};
/// Alphabet used when generating Mistral tool-call IDs: ASCII lowercase
/// letters, uppercase letters and digits (62 symbols).
const MISTRAL_ID_CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
/// `StreamFn` implementation that talks to Mistral's chat-completions endpoint.
///
/// Request building, sending and SSE parsing are delegated to a shared
/// `OaiAdapterShell` labelled `"Mistral"`. The Mistral-specific parts are
/// tool-call ID rewriting and message-sequence fixups, applied per request.
pub struct MistralStreamFn {
    // Shared OpenAI-compatible transport (URL, API key, HTTP plumbing).
    shell: OaiAdapterShell,
}
impl MistralStreamFn {
    /// Creates a Mistral adapter for the given API base URL and API key.
    ///
    /// Both values are passed straight to `OaiAdapterShell::new` together with
    /// the provider label `"Mistral"`.
    #[must_use]
    pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        let shell = OaiAdapterShell::new("Mistral", base_url, api_key);
        Self { shell }
    }
}
impl std::fmt::Debug for MistralStreamFn {
    // Delegates to the shell's debug helper so that shared formatting rules
    // (whatever the shell chooses to show or hide) apply here too.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = "MistralStreamFn";
        self.shell.fmt_debug(type_name, formatter)
    }
}
impl StreamFn for MistralStreamFn {
    // Builds the lazy request stream and erases its concrete type behind a
    // pinned trait object, as the `StreamFn` signature requires.
    fn stream<'a>(
        &'a self,
        model: &'a ModelSpec,
        context: &'a AgentContext,
        options: &'a StreamOptions,
        cancellation_token: CancellationToken,
    ) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
        let events = mistral_stream(self, model, context, options, cancellation_token);
        Box::pin(events)
    }
}
/// Bidirectional mapping between harness tool-call IDs and the 9-character
/// alphanumeric IDs sent to Mistral. A fresh map is created for each request.
struct MistralIdMap {
    // Harness ID -> Mistral ID.
    harness_to_mistral: HashMap<String, String>,
    // Mistral ID -> harness ID (the inverse of `harness_to_mistral`).
    mistral_to_harness: HashMap<String, String>,
    // Number of IDs produced by `generate_mistral_id`.
    // NOTE(review): only ever incremented in the code shown; no reader is visible.
    counter: u32,
}
impl MistralIdMap {
    /// Creates an empty mapping.
    fn new() -> Self {
        Self {
            harness_to_mistral: HashMap::new(),
            mistral_to_harness: HashMap::new(),
            counter: 0,
        }
    }

    /// Returns the Mistral-format ID for `harness_id`.
    ///
    /// The first time a harness ID is seen, a new ID is generated and recorded
    /// in both directions. Later calls return the same ID.
    fn remap_to_mistral(&mut self, harness_id: &str) -> String {
        if let Some(mid) = self.harness_to_mistral.get(harness_id) {
            return mid.clone();
        }
        // Re-roll in the (astronomically unlikely) case that the fresh ID is
        // already bound. This keeps one Mistral ID from mapping to two
        // harness IDs.
        let mid = loop {
            let candidate = self.generate_mistral_id();
            if !self.mistral_to_harness.contains_key(&candidate) {
                break candidate;
            }
        };
        self.harness_to_mistral
            .insert(harness_id.to_string(), mid.clone());
        self.mistral_to_harness
            .insert(mid.clone(), harness_id.to_string());
        mid
    }

    /// Returns the harness ID for a tool-call ID received from Mistral.
    ///
    /// Known IDs map back to their original harness ID. Unknown IDs become
    /// `call_{mistral_id}`, and that pairing is recorded in both directions.
    fn remap_to_harness(&mut self, mistral_id: &str) -> String {
        if let Some(hid) = self.mistral_to_harness.get(mistral_id) {
            return hid.clone();
        }
        let hid = format!("call_{mistral_id}");
        self.mistral_to_harness
            .insert(mistral_id.to_string(), hid.clone());
        self.harness_to_mistral
            .insert(hid.clone(), mistral_id.to_string());
        hid
    }

    /// Generates a fresh 9-character ID drawn uniformly from `MISTRAL_ID_CHARSET`.
    ///
    /// Randomness comes from v4 UUIDs, handled in three ways:
    /// - Bytes 6 and 8 are skipped because they carry the fixed UUID
    ///   version/variant bits. Using byte 6 would limit that character to 16
    ///   possibilities.
    /// - Bytes at or above the largest multiple of the charset size are
    ///   rejected, so `value % len` has no modulo bias.
    /// - If one UUID does not supply enough usable bytes, another is drawn.
    fn generate_mistral_id(&mut self) -> String {
        const ID_LEN: usize = 9;
        const CHARSET_LEN: usize = MISTRAL_ID_CHARSET.len();
        const UNBIASED_LIMIT: usize = 256 - 256 % CHARSET_LEN;

        let mut id = String::with_capacity(ID_LEN);
        while id.len() < ID_LEN {
            let uuid = uuid::Uuid::new_v4();
            let usable = uuid
                .as_bytes()
                .iter()
                .enumerate()
                .filter(|&(index, _)| index != 6 && index != 8)
                .map(|(_, &byte)| usize::from(byte))
                .filter(|&value| value < UNBIASED_LIMIT);
            for value in usable.take(ID_LEN - id.len()) {
                id.push(char::from(MISTRAL_ID_CHARSET[value % CHARSET_LEN]));
            }
        }
        // Wrapping: the counter is informational and must never panic on overflow.
        self.counter = self.counter.wrapping_add(1);
        id
    }
}
/// JSON body sent to Mistral's chat-completions endpoint.
///
/// Optional settings are omitted from the serialized payload when they are
/// `None`, and `tools` is omitted when empty.
#[derive(Debug, Serialize)]
struct MistralChatRequest {
    // Model identifier, taken from `ModelSpec::model_id` by the caller in this file.
    model: String,
    // Conversation after Mistral-specific conversion (ID remapping, role fixups).
    messages: Vec<OaiMessage>,
    // Set to `true` by `mistral_stream`, which consumes a streamed response.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<crate::openai_compat::OaiTool>,
    // Produced together with `tools` by `build_oai_tools`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
/// Builds and sends one streaming chat-completions request to Mistral.
///
/// Tool-call IDs are rewritten into Mistral's format on the way out. The same
/// per-request map then translates IDs in the returned events back to harness
/// IDs. No work happens until the returned stream is first polled.
fn mistral_stream<'a>(
    mistral: &'a MistralStreamFn,
    model: &'a ModelSpec,
    context: &'a AgentContext,
    options: &'a StreamOptions,
    cancellation_token: CancellationToken,
) -> impl Stream<Item = AssistantMessageEvent> + Send + 'a {
    let request_future = async move {
        let url = mistral.shell.chat_completions_url();
        debug!(
            %url,
            model = %model.model_id,
            messages = context.messages.len(),
            "sending Mistral request"
        );

        // One map per request: IDs only have to be consistent within it.
        let mut tool_ids = MistralIdMap::new();
        let messages =
            convert_messages_for_mistral(&context.messages, &context.system_prompt, &mut tool_ids);
        let (tools, tool_choice) = build_oai_tools(&context.tools);

        let body = MistralChatRequest {
            model: model.model_id.clone(),
            messages,
            stream: true,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            tools,
            tool_choice,
        };

        let parsed = oai_send_and_parse(
            mistral.shell.post_json_request(&url, &body, options),
            mistral.shell.provider(),
            cancellation_token,
            options.on_raw_payload.clone(),
            |_, _| None,
        );
        normalize_response_stream(parsed, tool_ids)
    };

    stream::once(request_future).flatten()
}
fn convert_messages_for_mistral(
messages: &[AgentMessage],
system_prompt: &str,
id_map: &mut MistralIdMap,
) -> Vec<OaiMessage> {
let raw_messages = convert::convert_messages::<OaiConverter>(messages, system_prompt);
let mut result: Vec<OaiMessage> = Vec::with_capacity(raw_messages.len() + 4);
let mut prev_was_tool = false;
for mut msg in raw_messages {
if prev_was_tool && msg.role == "user" {
result.push(OaiMessage {
role: "assistant".to_string(),
content: Some(String::new()),
tool_calls: None,
tool_call_id: None,
});
}
if msg.role == "assistant"
&& let Some(ref mut tool_calls) = msg.tool_calls
{
for tc in tool_calls.iter_mut() {
tc.id = id_map.remap_to_mistral(&tc.id);
}
}
if msg.role == "tool"
&& let Some(ref id) = msg.tool_call_id
{
msg.tool_call_id = Some(id_map.remap_to_mistral(id));
}
prev_was_tool = msg.role == "tool";
result.push(msg);
}
result
}
/// Translates Mistral tool-call IDs in `ToolCallStart` events back to harness
/// IDs using `id_map`. Every other event passes through unchanged.
fn normalize_response_stream(
    raw: impl Stream<Item = AssistantMessageEvent> + Send,
    mut id_map: MistralIdMap,
) -> impl Stream<Item = AssistantMessageEvent> + Send {
    raw.map(move |mut event| {
        if let AssistantMessageEvent::ToolCallStart { id, .. } = &mut event {
            *id = id_map.remap_to_harness(id.as_str());
        }
        event
    })
}
// Compile-time guard: the build fails if `MistralStreamFn` stops being
// `Send + Sync`.
const _: () = {
    const fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<MistralStreamFn>();
};