1use futures_util::StreamExt;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6#[derive(Serialize)]
7struct ChatRequest {
8 model: String,
9 messages: Vec<ChatMessage>,
10 stream: bool,
11 options: GenerateOptions,
12 keep_alive: String,
13}
14
15#[derive(Serialize)]
16struct ChatMessage {
17 role: String,
18 content: String,
19}
20
21#[derive(Serialize)]
22struct GenerateOptions {
23 num_predict: u32,
24 temperature: f32,
25 stop: Vec<String>,
26}
27
28#[derive(Deserialize)]
29struct ChatChunk {
30 message: Option<ChunkMessage>,
31 done: bool,
32 #[serde(flatten)]
33 _extra: std::collections::HashMap<String, Value>,
34}
35
36#[derive(Deserialize)]
37struct ChunkMessage {
38 content: String,
39}
40
41pub async fn generate(
42 endpoint: &str,
43 model: &str,
44 system_prompt: &str,
45 user_prompt: &str,
46 keep_alive: &str,
47) -> Result<String, String> {
48 let url = format!("{endpoint}/api/chat");
49 let client = Client::new();
50
51 let request = ChatRequest {
52 model: model.to_string(),
53 messages: vec![
54 ChatMessage {
55 role: "system".to_string(),
56 content: system_prompt.to_string(),
57 },
58 ChatMessage {
59 role: "user".to_string(),
60 content: user_prompt.to_string(),
61 },
62 ],
63 stream: true,
64 options: GenerateOptions {
65 num_predict: 512,
66 temperature: 0.1,
67 stop: vec!["\n\n\n".to_string()],
68 },
69 keep_alive: keep_alive.to_string(),
70 };
71
72 let response = client
73 .post(&url)
74 .json(&request)
75 .send()
76 .await
77 .map_err(|e| format!("Failed to connect to Ollama at {url}: {e}"))?;
78
79 if !response.status().is_success() {
80 let status = response.status();
81 let body = response.text().await.unwrap_or_default();
82 return Err(format!("Ollama returned {status}: {body}"));
83 }
84
85 let mut full_response = String::new();
86 let mut stream = response.bytes_stream();
87
88 while let Some(chunk_result) = stream.next().await {
89 let bytes = chunk_result.map_err(|e| format!("Stream error: {e}"))?;
90 let text = String::from_utf8_lossy(&bytes);
91
92 for line in text.lines() {
93 if line.trim().is_empty() {
94 continue;
95 }
96 match serde_json::from_str::<ChatChunk>(line) {
97 Ok(chunk) => {
98 if let Some(msg) = &chunk.message {
99 eprint!("{}", msg.content);
100 full_response.push_str(&msg.content);
101 }
102 if chunk.done {
103 eprintln!();
104 return Ok(full_response);
105 }
106 }
107 Err(_) => continue,
108 }
109 }
110 }
111
112 Ok(full_response)
113}