use crate::client::ModelClient;
use crate::client::handle_error_response;
use crate::client::json_lines_stream;
use crate::error::{OllamaError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio_stream::Stream;
/// Request body for the `api/chat` endpoint, serialized as the JSON body by
/// [`ModelClient::chat`] and [`ModelClient::chat_stream`].
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatRequest {
/// Name of the model to chat with; also handed to the error handlers as context.
pub model: String,
/// Conversation messages sent to the model.
pub messages: Vec<Message>,
/// Whether the server should stream the reply. Always serialized (no
/// `skip_serializing_if`); `false` when absent on deserialization or via `Default`.
#[serde(default)]
pub stream: bool,
/// Optional output-format constraint; see [`Format`].
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<Format>,
/// Extra model options, passed through verbatim as a JSON object.
/// NOTE(review): presumably runtime options such as `temperature` — confirm against server docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<HashMap<String, serde_json::Value>>,
/// How long the server should keep the model loaded after this request.
/// NOTE(review): presumably a duration string such as "5m" — confirm accepted formats.
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive: Option<String>,
/// Tools the model may call; see [`Tool`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Whether to request the model's reasoning ("thinking") output.
/// NOTE(review): presumably surfaced in `Message::thinking` of the reply — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub think: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Format {
Json,
Schema(serde_json::Value),
}
/// A single chat message, used both in `ChatRequest::messages` and as
/// `ChatResponse::message`.
///
/// Optional fields are omitted from the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Speaker role, e.g. "user", "assistant" or "system" (the roles the constructors below use).
pub role: String,
/// Message text.
pub content: String,
/// Images attached to the message.
/// NOTE(review): presumably base64-encoded image data — confirm against server docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub images: Option<Vec<String>>,
/// Tool invocations requested by the model; see [`ToolCall`].
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Name of the tool this message relates to.
/// NOTE(review): presumably set on tool-result messages — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Reasoning text produced by the model.
/// NOTE(review): presumably only populated when `ChatRequest::think` is enabled — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<String>,
}
impl Message {
    /// Builds a message with the given role and text and every optional field unset.
    ///
    /// Shared by the public constructors so a new optional field only has to be
    /// initialized in one place.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_string(),
            content: content.into(),
            images: None,
            tool_calls: None,
            tool_name: None,
            thinking: None,
        }
    }

    /// Creates a message with role `"user"` and the given text.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Creates a message with role `"assistant"` and the given text.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Creates a message with role `"system"` and the given text.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Creates a message with role `"tool"` carrying the output of `tool_name`.
    ///
    /// Intended for sending a tool's result back to the model after it requested
    /// the call via [`ToolCall`]; `tool_name` is stored in [`Message::tool_name`].
    pub fn tool(tool_name: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            tool_name: Some(tool_name.into()),
            ..Self::with_role("tool", content)
        }
    }
}
/// A tool definition offered to the model through `ChatRequest::tools`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool kind, serialized under the JSON key `type`.
/// NOTE(review): presumably "function" for function tools — confirm.
#[serde(rename = "type")]
pub tool_type: String,
/// The function the model may call.
pub function: ToolFunction,
}
/// Describes a callable function: its name, a description for the model, and its parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
/// Function name the model uses when requesting a call.
pub name: String,
/// Human-readable description of what the function does.
pub description: String,
/// Parameter specification as raw JSON.
/// NOTE(review): presumably a JSON Schema object — confirm against server docs.
pub parameters: serde_json::Value,
}
/// A tool invocation requested by the model, carried in `Message::tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// The function to invoke and its arguments.
pub function: ToolCallFunction,
}
/// Name and arguments of a function call requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
/// Name of the function to invoke.
pub name: String,
/// Arguments keyed by parameter name, kept as raw JSON values.
pub arguments: HashMap<String, serde_json::Value>,
}
/// Response from the `api/chat` endpoint: the whole reply for `ModelClient::chat`,
/// or one item of the stream returned by `ModelClient::chat_stream`.
///
/// The timing and token-count fields default to `0` when absent from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
/// Model that produced the response.
pub model: String,
/// Creation timestamp exactly as sent by the server.
/// NOTE(review): presumably an RFC 3339 string — confirm before parsing.
pub created_at: String,
/// The assistant message (or message fragment, when streaming).
pub message: Message,
/// Completion flag.
/// NOTE(review): presumably `true` only on the final item of a stream — confirm.
pub done: bool,
/// Why generation stopped, when the server reports it; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub done_reason: Option<String>,
/// Total time spent on the request.
/// NOTE(review): presumably nanoseconds, like the other durations below — confirm.
#[serde(default)]
pub total_duration: u64,
/// Time spent loading the model.
#[serde(default)]
pub load_duration: u64,
/// Number of prompt tokens evaluated.
#[serde(default)]
pub prompt_eval_count: u32,
/// Time spent evaluating the prompt.
#[serde(default)]
pub prompt_eval_duration: u64,
/// Number of tokens generated in the reply.
#[serde(default)]
pub eval_count: u32,
/// Time spent generating the reply.
#[serde(default)]
pub eval_duration: u64,
}
impl ModelClient {
    /// Sends a chat request to `api/chat` and returns the single, complete reply.
    ///
    /// `request.stream` is forced to `false` before sending. A streamed reply is a
    /// sequence of JSON lines and cannot be decoded as one [`ChatResponse`], so
    /// honoring a caller-supplied `true` here would only produce a decode error.
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be built from `base_url`.
    /// - [`OllamaError::RequestError`] if sending the HTTP request fails.
    /// - Any error produced by `handle_response` (defined elsewhere) while checking
    ///   the status and decoding the body.
    pub async fn chat(&self, mut request: ChatRequest) -> Result<ChatResponse> {
        request.stream = false;

        let url = self
            .base_url
            .join("api/chat")
            .map_err(OllamaError::UrlError)?;
        let response = self
            .client
            .post(url)
            .json(&request)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        self.handle_response(response, Some(&request.model)).await
    }

    /// Sends a chat request to `api/chat` and returns the reply as a stream of
    /// [`ChatResponse`] items decoded from the JSON-lines body.
    ///
    /// `request.stream` is forced to `true` before sending. Otherwise a request
    /// built with `ChatRequest::default()` (where `stream` is `false`) would
    /// silently yield a single non-streamed item.
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be built from `base_url`.
    /// - [`OllamaError::RequestError`] if sending the HTTP request fails.
    /// - The error built by `handle_error_response` (defined elsewhere) when the
    ///   server answers with a non-success status.
    ///
    /// Errors that occur while reading or decoding individual lines are yielded as
    /// `Err` items of the stream.
    pub async fn chat_stream(
        &self,
        mut request: ChatRequest,
    ) -> Result<impl Stream<Item = Result<ChatResponse>> + '_> {
        request.stream = true;

        let url = self
            .base_url
            .join("api/chat")
            .map_err(OllamaError::UrlError)?;
        let response = self
            .client
            .post(url)
            .json(&request)
            .send()
            .await
            .map_err(OllamaError::RequestError)?;
        // Check the status up front so HTTP errors surface from this call rather
        // than as the first stream item.
        if !response.status().is_success() {
            return Err(handle_error_response(response, Some(&request.model)).await);
        }
        Ok(json_lines_stream(response))
    }
}