1use crate::client::ModelClient;
16use crate::client::handle_error_response;
17use crate::client::json_lines_stream;
18use crate::error::{OllamaError, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use tokio_stream::Stream;
22
/// Request payload for a chat call.
///
/// Serialized to JSON and sent as the body of the `api/chat` POST by
/// `ModelClient::chat` and `ModelClient::chat_stream`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatRequest {
    /// Identifier of the model the request is addressed to.
    pub model: String,
    /// Messages that make up the conversation.
    pub messages: Vec<Message>,
    /// Streaming flag. Always serialized; falls back to `false` when the key is
    /// missing from deserialized input.
    #[serde(default)]
    pub stream: bool,
    /// Optional output-format constraint. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<Format>,
    /// Optional free-form key/value options. Left out of the JSON when `None`.
    // NOTE(review): presumably model/runtime parameters such as temperature — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Optional keep-alive setting, carried as a string. Left out of the JSON when `None`.
    // NOTE(review): accepted syntax (e.g. a duration like "5m") is not established here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Optional tool definitions sent along with the request. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Optional "think" toggle. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub think: Option<bool>,
}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(untagged)]
45pub enum Format {
46 Json,
47 Schema(serde_json::Value),
48}
49
/// A single chat message, used both in [`ChatRequest::messages`] and as the
/// `message` of a [`ChatResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role of the message author. The constructors on this type produce
    /// `"user"`, `"assistant"` and `"system"`.
    pub role: String,
    /// Text content of the message.
    pub content: String,
    /// Optional images attached to the message. Left out of the JSON when `None`.
    // NOTE(review): presumably base64-encoded image data — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<String>>,
    /// Optional tool calls carried by the message. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Optional tool name associated with the message. Left out of the JSON when `None`.
    // NOTE(review): presumably set on tool-result messages — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Optional "thinking" text. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<String>,
}
64
65impl Message {
66 pub fn user(content: impl Into<String>) -> Self {
67 Self {
68 role: "user".to_string(),
69 content: content.into(),
70 images: None,
71 tool_calls: None,
72 tool_name: None,
73 thinking: None,
74 }
75 }
76
77 pub fn assistant(content: impl Into<String>) -> Self {
78 Self {
79 role: "assistant".to_string(),
80 content: content.into(),
81 images: None,
82 tool_calls: None,
83 tool_name: None,
84 thinking: None,
85 }
86 }
87
88 pub fn system(content: impl Into<String>) -> Self {
89 Self {
90 role: "system".to_string(),
91 content: content.into(),
92 images: None,
93 tool_calls: None,
94 tool_name: None,
95 thinking: None,
96 }
97 }
98}
99
/// A tool definition that can be listed in [`ChatRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Kind of tool; appears under the JSON key `type`.
    // NOTE(review): presumably always "function" for this API — confirm.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Description of the callable function behind this tool.
    pub function: ToolFunction,
}
107
/// Describes the function exposed by a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter specification as an arbitrary JSON value.
    // NOTE(review): presumably a JSON Schema object — confirm against the API.
    pub parameters: serde_json::Value,
}
115
/// A tool invocation carried in [`Message::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// The function being invoked, together with its arguments.
    pub function: ToolCallFunction,
}
121
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments keyed by parameter name, each an arbitrary JSON value.
    pub arguments: HashMap<String, serde_json::Value>,
}
128
/// Reply decoded from the `api/chat` endpoint; also the item type yielded by
/// `ModelClient::chat_stream`.
///
/// The numeric fields are marked `#[serde(default)]`, so they deserialize to `0`
/// when the key is absent from the input.
// NOTE(review): the duration unit is not established in this file (Ollama documents
// nanoseconds) — confirm before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Model identifier reported in the reply.
    pub model: String,
    /// Creation timestamp, kept as the raw string from the reply.
    pub created_at: String,
    /// The message carried by this reply.
    pub message: Message,
    /// Completion flag reported in the reply.
    pub done: bool,
    /// Optional reason accompanying `done`. Left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub done_reason: Option<String>,
    /// Total duration; `0` when absent.
    #[serde(default)]
    pub total_duration: u64,
    /// Load duration; `0` when absent.
    #[serde(default)]
    pub load_duration: u64,
    /// Prompt evaluation count; `0` when absent.
    #[serde(default)]
    pub prompt_eval_count: u32,
    /// Prompt evaluation duration; `0` when absent.
    #[serde(default)]
    pub prompt_eval_duration: u64,
    /// Evaluation count; `0` when absent.
    #[serde(default)]
    pub eval_count: u32,
    /// Evaluation duration; `0` when absent.
    #[serde(default)]
    pub eval_duration: u64,
}
151
152impl ModelClient {
153 pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
155 let url = self
156 .base_url
157 .join("api/chat")
158 .map_err(OllamaError::UrlError)?;
159 let response = self
160 .client
161 .post(url)
162 .json(&request)
163 .send()
164 .await
165 .map_err(OllamaError::RequestError)?;
166
167 self.handle_response(response, Some(&request.model)).await
168 }
169
170 pub async fn chat_stream(
172 &self,
173 request: ChatRequest,
174 ) -> Result<impl Stream<Item = Result<ChatResponse>> + '_> {
175 let url = self
176 .base_url
177 .join("api/chat")
178 .map_err(OllamaError::UrlError)?;
179 let response = self
180 .client
181 .post(url)
182 .json(&request)
183 .send()
184 .await
185 .map_err(OllamaError::RequestError)?;
186
187 if !response.status().is_success() {
188 return Err(handle_error_response(response, Some(&request.model)).await);
189 }
190
191 Ok(json_lines_stream(response))
192 }
193}