use crate::client::{ModelClient, handle_error_response};
use crate::error::{OllamaError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio_stream::{Stream, StreamExt};
/// Request body sent to the `api/chat` endpoint.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`,
/// so only the settings the caller actually populated are transmitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Name of the model to chat with.
    pub model: String,
    /// Messages sent as the conversation.
    pub messages: Vec<Message>,
    /// Stream flag. It is always serialized, and it falls back to `false` when
    /// missing during deserialization.
    #[serde(default)]
    pub stream: bool,
    /// Optional output-format constraint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<Format>,
    /// Free-form model options, forwarded verbatim as JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Keep-alive setting, forwarded verbatim as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Tool definitions made available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Thinking toggle, forwarded verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub think: Option<bool>,
}
impl Default for ChatRequest {
    /// Builds a non-streaming request for the `llama3` model with no messages
    /// and every optional setting left unset.
    fn default() -> Self {
        let model = String::from("llama3");
        Self {
            model,
            messages: Vec::new(),
            stream: false,
            format: None,
            options: None,
            keep_alive: None,
            tools: None,
            think: None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Format {
Json,
Schema(serde_json::Value),
}
/// A single chat message, used both in requests and in responses.
///
/// Optional fields are omitted from serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role. This is a free-form string and is not validated here.
    pub role: String,
    /// Text content of the message.
    pub content: String,
    /// Attached images (presumably base64-encoded; confirm against the API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>,
    /// Tool invocations carried by this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Tool name associated with this message (presumably set on tool-result
    /// messages; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Reasoning text, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<String>,
}
/// A tool definition offered to the model in a [`ChatRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind, serialized under the JSON key `type` (presumably `"function"`;
    /// confirm against the API).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The callable function described by this tool.
    pub function: ToolFunction,
}
/// Description of a function the model may call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Function name the model refers to when calling it.
    pub name: String,
    /// Human-readable description of what the function does.
    pub description: String,
    /// Parameter description as a raw JSON value (presumably a JSON Schema
    /// object; confirm).
    pub parameters: serde_json::Value,
}
/// A tool invocation attached to a [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The function being invoked and its arguments.
    pub function: ToolCallFunction,
}
/// Function name and arguments for a single [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments keyed by parameter name, kept as raw JSON values.
    pub arguments: HashMap<String, serde_json::Value>,
}
/// A response object from the `api/chat` endpoint.
///
/// This is either the whole reply (non-streaming) or one streamed record.
/// The numeric statistics default to `0` when the server omits them.
/// Duration units are not established in this file (Ollama documents
/// nanoseconds; confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Model that produced the response.
    pub model: String,
    /// Creation timestamp, kept as the unparsed string the server sent.
    pub created_at: String,
    /// The (possibly partial) assistant message.
    pub message: Message,
    /// Completion flag reported by the server (presumably `true` on the final
    /// record of a stream).
    pub done: bool,
    /// Reason generation stopped, if reported; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub done_reason: Option<String>,
    /// Total time spent on the request.
    #[serde(default)]
    pub total_duration: u64,
    /// Time spent loading the model.
    #[serde(default)]
    pub load_duration: u64,
    /// Number of prompt tokens evaluated.
    #[serde(default)]
    pub prompt_eval_count: u32,
    /// Time spent evaluating the prompt.
    #[serde(default)]
    pub prompt_eval_duration: u64,
    /// Number of tokens generated.
    #[serde(default)]
    pub eval_count: u32,
    /// Time spent generating tokens.
    #[serde(default)]
    pub eval_duration: u64,
}
impl ModelClient {
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = self
.base_url
.join("api/chat")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
self.handle_response(response, Some(&request.model)).await
}
pub async fn chat_stream(
&self,
request: ChatRequest,
) -> Result<impl Stream<Item = Result<ChatResponse>> + '_> {
let url = self
.base_url
.join("api/chat")
.map_err(OllamaError::UrlError)?;
let response = self
.client
.post(url)
.json(&request)
.send()
.await
.map_err(OllamaError::RequestError)?;
if !response.status().is_success() {
return Err(handle_error_response(response, Some(&request.model)).await);
}
let stream = response.bytes_stream().map(|result| {
let bytes = result.map_err(OllamaError::RequestError)?;
let line = std::str::from_utf8(&bytes)
.map_err(|_| OllamaError::StreamError("Invalid UTF-8 in response".to_string()))?
.trim_ascii();
if line.is_empty() {
return Err(OllamaError::StreamError("Empty line in stream".to_string()));
}
serde_json::from_str::<ChatResponse>(line)
.map_err(OllamaError::JsonError)
.map(Ok)
.unwrap()
});
Ok(stream)
}
}