1use crate::chat::Format;
16use crate::client::ModelClient;
17use crate::client::handle_error_response;
18use crate::client::json_lines_stream;
19use crate::error::{OllamaError, Result};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use tokio_stream::Stream;
23
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
26pub struct GenerateRequest {
27 pub model: String,
28 pub prompt: String,
29 #[serde(default)]
30 pub stream: bool,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub suffix: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub images: Option<Vec<String>>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 pub format: Option<Format>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub options: Option<HashMap<String, serde_json::Value>>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub system: Option<String>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 pub template: Option<String>,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub raw: Option<bool>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 pub keep_alive: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub context: Option<Vec<u32>>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub think: Option<bool>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub width: Option<u32>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub height: Option<u32>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub steps: Option<u32>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct GenerateResponse {
62 pub model: String,
63 pub created_at: String,
64 pub response: String,
65 pub done: bool,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub done_reason: Option<String>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 pub context: Option<Vec<u32>>,
70 #[serde(default)]
71 pub total_duration: u64,
72 #[serde(default)]
73 pub load_duration: u64,
74 #[serde(default)]
75 pub prompt_eval_count: u32,
76 #[serde(default)]
77 pub prompt_eval_duration: u64,
78 #[serde(default)]
79 pub eval_count: u32,
80 #[serde(default)]
81 pub eval_duration: u64,
82}
83
84impl ModelClient {
85 pub async fn generate(&self, request: GenerateRequest) -> Result<GenerateResponse> {
87 let url = self
88 .base_url
89 .join("api/generate")
90 .map_err(OllamaError::UrlError)?;
91 let response = self
92 .client
93 .post(url)
94 .json(&request)
95 .send()
96 .await
97 .map_err(OllamaError::RequestError)?;
98
99 self.handle_response(response, Some(&request.model)).await
100 }
101
102 pub async fn generate_stream(
104 &self,
105 request: GenerateRequest,
106 ) -> Result<impl Stream<Item = Result<GenerateResponse>> + '_> {
107 let url = self
108 .base_url
109 .join("api/generate")
110 .map_err(OllamaError::UrlError)?;
111 let response = self
112 .client
113 .post(url)
114 .json(&request)
115 .send()
116 .await
117 .map_err(OllamaError::RequestError)?;
118
119 if !response.status().is_success() {
120 return Err(handle_error_response(response, Some(&request.model)).await);
121 }
122
123 Ok(json_lines_stream(response))
124 }
125}