Skip to main content

oai_sdk/
chat.rs

1// Copyright 2026 Cloudflavor GmbH
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::ModelClient;
16use crate::client::handle_error_response;
17use crate::client::json_lines_stream;
18use crate::error::{OllamaError, Result};
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use tokio_stream::Stream;
22
23/// Request for chat completion.
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct ChatRequest {
26    pub model: String,
27    pub messages: Vec<Message>,
28    #[serde(default)]
29    pub stream: bool,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub format: Option<Format>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub options: Option<HashMap<String, serde_json::Value>>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub keep_alive: Option<String>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub tools: Option<Vec<Tool>>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub think: Option<bool>,
40}
41
42/// Format for the response.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(untagged)]
45pub enum Format {
46    Json,
47    Schema(serde_json::Value),
48}
49
50/// A message in a chat.
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct Message {
53    pub role: String,
54    pub content: String,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub images: Option<Vec<String>>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub tool_calls: Option<Vec<ToolCall>>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub tool_name: Option<String>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub thinking: Option<String>,
63}
64
65impl Message {
66    pub fn user(content: impl Into<String>) -> Self {
67        Self {
68            role: "user".to_string(),
69            content: content.into(),
70            images: None,
71            tool_calls: None,
72            tool_name: None,
73            thinking: None,
74        }
75    }
76
77    pub fn assistant(content: impl Into<String>) -> Self {
78        Self {
79            role: "assistant".to_string(),
80            content: content.into(),
81            images: None,
82            tool_calls: None,
83            tool_name: None,
84            thinking: None,
85        }
86    }
87
88    pub fn system(content: impl Into<String>) -> Self {
89        Self {
90            role: "system".to_string(),
91            content: content.into(),
92            images: None,
93            tool_calls: None,
94            tool_name: None,
95            thinking: None,
96        }
97    }
98}
99
100/// A tool that can be used by the model.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct Tool {
103    #[serde(rename = "type")]
104    pub tool_type: String,
105    pub function: ToolFunction,
106}
107
108/// A tool function.
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ToolFunction {
111    pub name: String,
112    pub description: String,
113    pub parameters: serde_json::Value,
114}
115
116/// A tool call.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ToolCall {
119    pub function: ToolCallFunction,
120}
121
122/// A tool call function.
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ToolCallFunction {
125    pub name: String,
126    pub arguments: HashMap<String, serde_json::Value>,
127}
128
129/// Response for chat completion.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ChatResponse {
132    pub model: String,
133    pub created_at: String,
134    pub message: Message,
135    pub done: bool,
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub done_reason: Option<String>,
138    #[serde(default)]
139    pub total_duration: u64,
140    #[serde(default)]
141    pub load_duration: u64,
142    #[serde(default)]
143    pub prompt_eval_count: u32,
144    #[serde(default)]
145    pub prompt_eval_duration: u64,
146    #[serde(default)]
147    pub eval_count: u32,
148    #[serde(default)]
149    pub eval_duration: u64,
150}
151
152impl ModelClient {
153    /// Generate a chat completion.
154    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
155        let url = self
156            .base_url
157            .join("api/chat")
158            .map_err(OllamaError::UrlError)?;
159        let response = self
160            .client
161            .post(url)
162            .json(&request)
163            .send()
164            .await
165            .map_err(OllamaError::RequestError)?;
166
167        self.handle_response(response, Some(&request.model)).await
168    }
169
170    /// Generate a streaming chat completion.
171    pub async fn chat_stream(
172        &self,
173        request: ChatRequest,
174    ) -> Result<impl Stream<Item = Result<ChatResponse>> + '_> {
175        let url = self
176            .base_url
177            .join("api/chat")
178            .map_err(OllamaError::UrlError)?;
179        let response = self
180            .client
181            .post(url)
182            .json(&request)
183            .send()
184            .await
185            .map_err(OllamaError::RequestError)?;
186
187        if !response.status().is_success() {
188            return Err(handle_error_response(response, Some(&request.model)).await);
189        }
190
191        Ok(json_lines_stream(response))
192    }
193}