Skip to main content

oai_sdk/
generate.rs

1// Copyright 2026 Cloudflavor GmbH
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::chat::Format;
16use crate::client::ModelClient;
17use crate::client::handle_error_response;
18use crate::client::json_lines_stream;
19use crate::error::{OllamaError, Result};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use tokio_stream::Stream;
23
24/// Request for text generation.
25#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct GenerateRequest {
27    pub model: String,
28    pub prompt: String,
29    #[serde(default)]
30    pub stream: bool,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub suffix: Option<String>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub images: Option<Vec<String>>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub format: Option<Format>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub options: Option<HashMap<String, serde_json::Value>>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub system: Option<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub template: Option<String>,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub raw: Option<bool>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub keep_alive: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub context: Option<Vec<u32>>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub think: Option<bool>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub width: Option<u32>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub height: Option<u32>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub steps: Option<u32>,
57}
58
59/// Response for text generation.
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct GenerateResponse {
62    pub model: String,
63    pub created_at: String,
64    pub response: String,
65    pub done: bool,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub done_reason: Option<String>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub context: Option<Vec<u32>>,
70    #[serde(default)]
71    pub total_duration: u64,
72    #[serde(default)]
73    pub load_duration: u64,
74    #[serde(default)]
75    pub prompt_eval_count: u32,
76    #[serde(default)]
77    pub prompt_eval_duration: u64,
78    #[serde(default)]
79    pub eval_count: u32,
80    #[serde(default)]
81    pub eval_duration: u64,
82}
83
84impl ModelClient {
85    /// Generate text from a prompt.
86    pub async fn generate(&self, request: GenerateRequest) -> Result<GenerateResponse> {
87        let url = self
88            .base_url
89            .join("api/generate")
90            .map_err(OllamaError::UrlError)?;
91        let response = self
92            .client
93            .post(url)
94            .json(&request)
95            .send()
96            .await
97            .map_err(OllamaError::RequestError)?;
98
99        self.handle_response(response, Some(&request.model)).await
100    }
101
102    /// Generate text from a prompt with streaming.
103    pub async fn generate_stream(
104        &self,
105        request: GenerateRequest,
106    ) -> Result<impl Stream<Item = Result<GenerateResponse>> + '_> {
107        let url = self
108            .base_url
109            .join("api/generate")
110            .map_err(OllamaError::UrlError)?;
111        let response = self
112            .client
113            .post(url)
114            .json(&request)
115            .send()
116            .await
117            .map_err(OllamaError::RequestError)?;
118
119        if !response.status().is_success() {
120            return Err(handle_error_response(response, Some(&request.model)).await);
121        }
122
123        Ok(json_lines_stream(response))
124    }
125}