//! `git_cli/ollama.rs` — minimal streaming client for the Ollama `/api/chat` endpoint.

1use futures_util::StreamExt;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6#[derive(Serialize)]
7struct ChatRequest {
8    model: String,
9    messages: Vec<ChatMessage>,
10    stream: bool,
11    options: GenerateOptions,
12    keep_alive: String,
13}
14
15#[derive(Serialize)]
16struct ChatMessage {
17    role: String,
18    content: String,
19}
20
21#[derive(Serialize)]
22struct GenerateOptions {
23    num_predict: u32,
24    temperature: f32,
25    stop: Vec<String>,
26}
27
28#[derive(Deserialize)]
29struct ChatChunk {
30    message: Option<ChunkMessage>,
31    done: bool,
32    #[serde(flatten)]
33    _extra: std::collections::HashMap<String, Value>,
34}
35
36#[derive(Deserialize)]
37struct ChunkMessage {
38    content: String,
39}
40
41pub async fn generate(
42    endpoint: &str,
43    model: &str,
44    system_prompt: &str,
45    user_prompt: &str,
46    keep_alive: &str,
47) -> Result<String, String> {
48    let url = format!("{endpoint}/api/chat");
49    let client = Client::new();
50
51    let request = ChatRequest {
52        model: model.to_string(),
53        messages: vec![
54            ChatMessage {
55                role: "system".to_string(),
56                content: system_prompt.to_string(),
57            },
58            ChatMessage {
59                role: "user".to_string(),
60                content: user_prompt.to_string(),
61            },
62        ],
63        stream: true,
64        options: GenerateOptions {
65            num_predict: 512,
66            temperature: 0.1,
67            stop: vec!["\n\n\n".to_string()],
68        },
69        keep_alive: keep_alive.to_string(),
70    };
71
72    let response = client
73        .post(&url)
74        .json(&request)
75        .send()
76        .await
77        .map_err(|e| format!("Failed to connect to Ollama at {url}: {e}"))?;
78
79    if !response.status().is_success() {
80        let status = response.status();
81        let body = response.text().await.unwrap_or_default();
82        return Err(format!("Ollama returned {status}: {body}"));
83    }
84
85    let mut full_response = String::new();
86    let mut stream = response.bytes_stream();
87
88    while let Some(chunk_result) = stream.next().await {
89        let bytes = chunk_result.map_err(|e| format!("Stream error: {e}"))?;
90        let text = String::from_utf8_lossy(&bytes);
91
92        for line in text.lines() {
93            if line.trim().is_empty() {
94                continue;
95            }
96            match serde_json::from_str::<ChatChunk>(line) {
97                Ok(chunk) => {
98                    if let Some(msg) = &chunk.message {
99                        eprint!("{}", msg.content);
100                        full_response.push_str(&msg.content);
101                    }
102                    if chunk.done {
103                        eprintln!();
104                        return Ok(full_response);
105                    }
106                }
107                Err(_) => continue,
108            }
109        }
110    }
111
112    Ok(full_response)
113}