lemon_llm/
ollama.rs

1use serde_json::json;
2
3use crate::{GenerateError, LlmBackend};
4
5const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
6
7pub struct OllamaBackend {
8    pub model: OllamaModel,
9    pub url: String,
10}
11
12impl Default for OllamaBackend {
13    fn default() -> Self {
14        Self {
15            model: OllamaModel::default(),
16            url: DEFAULT_OLLAMA_URL.to_string(),
17        }
18    }
19}
20
/// Models this backend can request, each mapped to its Ollama model name by
/// [`OllamaModel::as_str`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OllamaModel {
    /// Ollama model name `llama2`.
    Llama2,
    /// Ollama model name `llama2-uncensored`.
    Llama2Uncensored,
    /// Ollama model name `mistral`; this is the default model.
    #[default]
    Mistral7B,
}

impl OllamaModel {
    /// Returns the model name sent to the Ollama HTTP API.
    ///
    /// Every name is a string literal, so the result is `'static` and does not
    /// borrow from `self`. Callers can keep it after the model value is gone.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Llama2 => "llama2",
            Self::Llama2Uncensored => "llama2-uncensored",
            Self::Mistral7B => "mistral",
        }
    }
}
38
39impl LlmBackend for OllamaBackend {
40    async fn generate(&self, prompt: &str) -> Result<String, GenerateError> {
41        // Pull the model if it's not already downloaded.
42        reqwest::Client::new()
43            .post(format!("{}/api/pull", self.url))
44            .json(&json!({
45                "name": self.model.as_str(),
46            }))
47            .send()
48            .await
49            .map_err(|e| GenerateError::BackendError(e.to_string()))?;
50
51        // Run the model.
52        let response = reqwest::Client::new()
53            .post(format!("{}/api/generate", self.url))
54            .json(&json!({
55                "model": self.model.as_str(),
56                "prompt": prompt,
57            }))
58            .send()
59            .await
60            .map_err(|e| GenerateError::BackendError(e.to_string()))?;
61
62        let text = response
63            .text()
64            .await
65            .map_err(|e| GenerateError::BackendError(e.to_string()))?;
66
67        Ok(text
68            .lines()
69            .map(|line| {
70                // Extract the text from the JSON response.
71                let json: serde_json::Value = serde_json::from_str(line).unwrap();
72                json["response"].as_str().unwrap_or_default().to_string()
73            })
74            .collect::<String>()
75            .trim()
76            .to_string())
77    }
78}