use super::types::*;
use crate::error::LlmError;
use crate::types::{ChatMessage, Tool, ToolCall};
use crate::utils::http_headers::ProviderHeaders;
use reqwest::header::HeaderMap;
use std::collections::HashMap;
/// Builds the HTTP headers sent with every Ollama request.
///
/// The provider defaults come from [`ProviderHeaders::ollama`], which also merges in
/// `additional_headers`.
///
/// # Errors
///
/// Propagates any error reported by [`ProviderHeaders::ollama`].
pub fn build_headers(additional_headers: &HashMap<String, String>) -> Result<HeaderMap, LlmError> {
    let headers = ProviderHeaders::ollama(additional_headers)?;
    Ok(headers)
}
/// Returns the base64 payload of an RFC 2397 `data:<mime>;base64,<payload>` URL.
///
/// Any other string is returned unchanged. Raw base64 contains no `:`, so it never
/// matches the `data:` prefix and passes through untouched.
fn strip_data_url_prefix(url: &str) -> &str {
    url.strip_prefix("data:")
        .and_then(|rest| rest.split_once(";base64,"))
        .map_or(url, |(_, payload)| payload)
}

/// Converts a provider-agnostic [`ChatMessage`] into Ollama's wire format.
///
/// - `Developer` messages are sent with the `"system"` role.
/// - For multimodal content, text parts are joined with a single space into `content`.
/// - Image parts are collected into `images`. A `data:…;base64,` prefix is stripped,
///   because Ollama's native API expects raw base64 there.
/// - `images` is `None` when no image parts are present.
/// - Tool calls are converted with [`convert_tool_call`].
pub fn convert_chat_message(message: &ChatMessage) -> OllamaChatMessage {
    let role = match message.role {
        crate::types::MessageRole::System => "system",
        crate::types::MessageRole::User => "user",
        crate::types::MessageRole::Assistant => "assistant",
        // Ollama has no developer role; developer instructions travel as system.
        crate::types::MessageRole::Developer => "system",
        crate::types::MessageRole::Tool => "tool",
    };

    let (content, images) = match &message.content {
        crate::types::MessageContent::Text(text) => (text.clone(), Vec::new()),
        crate::types::MessageContent::MultiModal(parts) => {
            // Single pass: text parts feed `content`, image parts feed `images`.
            // Any other part kind is dropped.
            let mut texts: Vec<&str> = Vec::new();
            let mut images = Vec::new();
            for part in parts {
                if let crate::types::ContentPart::Text { text } = part {
                    texts.push(text.as_str());
                } else if let crate::types::ContentPart::Image { image_url, .. } = part {
                    images.push(strip_data_url_prefix(image_url).to_string());
                }
            }
            (texts.join(" "), images)
        }
    };

    OllamaChatMessage {
        role: role.to_string(),
        content,
        // Never send an empty list; omit the field instead.
        images: if images.is_empty() { None } else { Some(images) },
        tool_calls: message
            .tool_calls
            .as_ref()
            .map(|calls| calls.iter().map(convert_tool_call).collect()),
        thinking: None,
    }
}
/// Converts a tool definition into Ollama's `{"type": "function", "function": {...}}` shape.
///
/// The name, description and JSON-schema parameters are copied as-is.
pub fn convert_tool(tool: &Tool) -> OllamaTool {
    let source = &tool.function;
    let ollama_function = OllamaFunction {
        name: source.name.clone(),
        description: source.description.clone(),
        parameters: source.parameters.clone(),
    };
    OllamaTool {
        tool_type: String::from("function"),
        function: ollama_function,
    }
}
/// Converts a [`ToolCall`] into Ollama's tool-call representation.
///
/// Ollama expects `arguments` as a JSON value rather than a JSON-encoded string, so the
/// string is parsed. This is best-effort:
/// - Unparseable arguments become an empty object `{}`.
/// - A call without a function gets an empty name and `{}` arguments.
pub fn convert_tool_call(tool_call: &ToolCall) -> OllamaToolCall {
    let (name, arguments) = match tool_call.function.as_ref() {
        Some(call) => (
            call.name.clone(),
            serde_json::from_str(&call.arguments)
                .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())),
        ),
        None => (String::new(), serde_json::Value::Object(serde_json::Map::new())),
    };
    OllamaToolCall {
        function: OllamaFunctionCall { name, arguments },
    }
}
/// Converts an Ollama response message back into a provider-agnostic [`ChatMessage`].
///
/// - Unknown roles are mapped to `Assistant`.
/// - When the message carries at least one image, the content becomes `MultiModal`:
///   the text part comes first, followed by one image part per entry.
/// - Otherwise, including when `images` is present but empty, the content stays
///   plain `Text`.
/// - Tool calls are converted with [`convert_from_ollama_tool_call`].
/// - `thinking` is not carried over.
pub fn convert_from_ollama_message(message: &OllamaChatMessage) -> ChatMessage {
    let role = match message.role.as_str() {
        "system" => crate::types::MessageRole::System,
        "user" => crate::types::MessageRole::User,
        "assistant" => crate::types::MessageRole::Assistant,
        "tool" => crate::types::MessageRole::Tool,
        // Anything unrecognised is treated as model output.
        _ => crate::types::MessageRole::Assistant,
    };

    let content = match message.images.as_deref() {
        // An empty image list carries no extra information; don't wrap plain text
        // in a single-part multimodal message for it.
        Some(images) if !images.is_empty() => {
            let mut parts = Vec::with_capacity(images.len() + 1);
            parts.push(crate::types::ContentPart::Text {
                text: message.content.clone(),
            });
            parts.extend(images.iter().map(|image_url| crate::types::ContentPart::Image {
                image_url: image_url.clone(),
                detail: None,
            }));
            crate::types::MessageContent::MultiModal(parts)
        }
        _ => crate::types::MessageContent::Text(message.content.clone()),
    };

    ChatMessage {
        role,
        content,
        metadata: crate::types::MessageMetadata::default(),
        tool_calls: message
            .tool_calls
            .as_ref()
            .map(|calls| calls.iter().map(convert_from_ollama_tool_call).collect()),
        tool_call_id: None,
    }
}
/// Converts an Ollama tool call into a [`ToolCall`], synthesising an id for it.
///
/// Ollama does not assign tool-call ids. The generated id has the form
/// `call_<unix-millis>_<seq>`, where `<seq>` is a process-wide counter. The counter
/// keeps ids distinct even when many calls are converted within the same millisecond,
/// as happens for every call of a multi-call message. `arguments` is re-serialised
/// to a JSON string.
pub fn convert_from_ollama_tool_call(tool_call: &OllamaToolCall) -> ToolCall {
    static NEXT_CALL_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    // Relaxed is enough: uniqueness only needs the increment itself to be atomic.
    let seq = NEXT_CALL_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    ToolCall {
        id: format!("call_{}_{}", chrono::Utc::now().timestamp_millis(), seq),
        r#type: "function".to_string(),
        function: Some(crate::types::FunctionCall {
            name: tool_call.function.name.clone(),
            arguments: tool_call.function.arguments.to_string(),
        }),
    }
}
/// Parses one line of a streaming chat response into JSON.
///
/// Ollama's native API streams newline-delimited JSON, but Server-Sent-Events framing
/// is tolerated as well:
/// - `data:` payloads are unwrapped, with or without the space after the colon.
/// - These lines yield `Ok(None)`: blank lines, SSE comments (`:`…), the other SSE
///   fields (`event:`, `id:`, `retry:`), empty `data:` lines and the `[DONE]` sentinel.
///
/// # Errors
///
/// Returns `LlmError::ParseError` when the payload is not valid JSON.
pub fn parse_streaming_line(line: &str) -> Result<Option<serde_json::Value>, LlmError> {
    // SSE framing lines that never carry a JSON payload.
    const IGNORED_SSE_FIELDS: [&str; 3] = ["event:", "id:", "retry:"];

    let line = line.trim();
    if line.is_empty()
        || line.starts_with(':')
        || IGNORED_SSE_FIELDS.iter().any(|&field| line.starts_with(field))
    {
        return Ok(None);
    }

    // Per the SSE spec the space after `data:` is optional.
    let payload = line.strip_prefix("data:").map_or(line, str::trim_start);
    if payload.is_empty() || payload == "[DONE]" {
        return Ok(None);
    }

    serde_json::from_str(payload)
        .map(Some)
        .map_err(|e| LlmError::ParseError(format!("Failed to parse streaming response: {e}")))
}
/// Returns the model name to send to Ollama.
///
/// Ollama accepts the identifier verbatim (including any `:tag`), so this is an owned
/// copy of the input.
pub fn extract_model_name(model: &str) -> String {
    model.to_owned()
}
/// Checks that `model` is usable as an Ollama model identifier.
///
/// # Errors
///
/// Returns `LlmError::ConfigurationError` when:
/// - the name is empty, or
/// - the name contains any whitespace or control character. This covers `' '`,
///   `'\n'`, `'\t'`, and also `'\r'` (e.g. from CRLF config files), non-ASCII spaces
///   and NUL. The server can never resolve such a name.
pub fn validate_model_name(model: &str) -> Result<(), LlmError> {
    if model.is_empty() {
        return Err(LlmError::ConfigurationError(
            "Model name cannot be empty".to_string(),
        ));
    }
    if model.chars().any(|c| c.is_whitespace() || c.is_control()) {
        return Err(LlmError::ConfigurationError(
            "Model name contains invalid characters".to_string(),
        ));
    }
    Ok(())
}
/// Converts an `f32` sampling parameter into a JSON number.
///
/// JSON cannot represent non-finite values (NaN, ±infinity); those fall back to `0`.
fn f32_option_value(value: f32) -> serde_json::Value {
    serde_json::Value::Number(
        serde_json::Number::from_f64(f64::from(value))
            .unwrap_or_else(|| serde_json::Number::from(0)),
    )
}

/// Builds the `options` object of an Ollama request from common sampling parameters.
///
/// - Only parameters that are `Some` are emitted.
/// - `max_tokens` is sent under Ollama's `num_predict` key.
/// - Entries from `additional_options` are applied last, so they override any key
///   produced from the typed parameters.
pub fn build_model_options(
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    top_p: Option<f32>,
    frequency_penalty: Option<f32>,
    presence_penalty: Option<f32>,
    additional_options: Option<&HashMap<String, serde_json::Value>>,
) -> HashMap<String, serde_json::Value> {
    let mut options = HashMap::new();

    let float_options = [
        ("temperature", temperature),
        ("top_p", top_p),
        ("frequency_penalty", frequency_penalty),
        ("presence_penalty", presence_penalty),
    ];
    for (key, value) in float_options {
        if let Some(value) = value {
            options.insert(key.to_string(), f32_option_value(value));
        }
    }

    if let Some(max_tokens) = max_tokens {
        options.insert(
            "num_predict".to_string(),
            serde_json::Value::Number(serde_json::Number::from(max_tokens)),
        );
    }

    if let Some(additional) = additional_options {
        options.extend(additional.iter().map(|(k, v)| (k.clone(), v.clone())));
    }

    options
}
/// Computes generation throughput from Ollama's `eval_count` and `eval_duration`.
///
/// `eval_duration` is interpreted as nanoseconds. Returns `None` when either value is
/// missing or when the duration is zero (avoiding a division by zero).
pub fn calculate_tokens_per_second(
    eval_count: Option<u32>,
    eval_duration: Option<u64>,
) -> Option<f64> {
    const NANOS_PER_SECOND: f64 = 1_000_000_000.0;

    let tokens = eval_count?;
    let nanos = eval_duration.filter(|&d| d > 0)?;
    let seconds = nanos as f64 / NANOS_PER_SECOND;
    Some(f64::from(tokens) / seconds)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentPart, MessageContent, MessageMetadata, MessageRole};
    use reqwest::header::{CONTENT_TYPE, USER_AGENT};

    /// With no extra headers, the provider defaults must still be present.
    #[test]
    fn test_build_headers() {
        let headers = build_headers(&HashMap::new()).unwrap();
        for required in [CONTENT_TYPE, USER_AGENT] {
            assert!(headers.contains_key(&required));
        }
    }

    /// Multimodal input splits into text `content` and an `images` list.
    #[test]
    fn test_convert_chat_message() {
        let parts = vec![
            ContentPart::Text {
                text: "Hello".to_string(),
            },
            ContentPart::Image {
                image_url: "image1".to_string(),
                detail: None,
            },
        ];
        let message = ChatMessage {
            role: MessageRole::User,
            content: MessageContent::MultiModal(parts),
            metadata: MessageMetadata::default(),
            tool_calls: None,
            tool_call_id: None,
        };

        let converted = convert_chat_message(&message);

        assert_eq!(converted.role, "user");
        assert_eq!(converted.content, "Hello");
        assert_eq!(converted.images, Some(vec!["image1".to_string()]));
    }

    #[test]
    fn test_validate_model_name() {
        for valid in ["llama3.2", "llama3.2:latest"] {
            assert!(validate_model_name(valid).is_ok(), "expected {valid:?} to be valid");
        }
        for invalid in ["", "model with spaces"] {
            assert!(validate_model_name(invalid).is_err(), "expected {invalid:?} to be rejected");
        }
    }

    #[test]
    fn test_calculate_tokens_per_second() {
        let cases: [(Option<u32>, Option<u64>, Option<f64>); 5] = [
            (Some(100), Some(1_000_000_000), Some(100.0)),
            (Some(50), Some(500_000_000), Some(100.0)),
            (None, Some(1_000_000_000), None),
            (Some(100), None, None),
            (Some(100), Some(0), None),
        ];
        for (count, duration, expected) in cases {
            assert_eq!(calculate_tokens_per_second(count, duration), expected);
        }
    }
}