1use serde_json::json;
2
3use crate::{GenerateError, LlmBackend};
4
/// Base URL of a locally running Ollama server (Ollama's default port).
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434";
6
7pub struct OllamaBackend {
8 pub model: OllamaModel,
9 pub url: String,
10}
11
12impl Default for OllamaBackend {
13 fn default() -> Self {
14 Self {
15 model: OllamaModel::default(),
16 url: DEFAULT_OLLAMA_URL.to_string(),
17 }
18 }
19}
20
/// Models this backend knows how to request from an Ollama server.
///
/// Variants map to Ollama model tags via [`OllamaModel::as_str`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OllamaModel {
    /// Meta's Llama 2 (`llama2`).
    Llama2,
    /// Uncensored Llama 2 variant (`llama2-uncensored`).
    Llama2Uncensored,
    /// Mistral 7B (`mistral`) — the default.
    #[default]
    Mistral7B,
}
28
29impl OllamaModel {
30 pub fn as_str(&self) -> &str {
31 match self {
32 Self::Llama2 => "llama2",
33 Self::Llama2Uncensored => "llama2-uncensored",
34 Self::Mistral7B => "mistral",
35 }
36 }
37}
38
39impl LlmBackend for OllamaBackend {
40 async fn generate(&self, prompt: &str) -> Result<String, GenerateError> {
41 reqwest::Client::new()
43 .post(format!("{}/api/pull", self.url))
44 .json(&json!({
45 "name": self.model.as_str(),
46 }))
47 .send()
48 .await
49 .map_err(|e| GenerateError::BackendError(e.to_string()))?;
50
51 let response = reqwest::Client::new()
53 .post(format!("{}/api/generate", self.url))
54 .json(&json!({
55 "model": self.model.as_str(),
56 "prompt": prompt,
57 }))
58 .send()
59 .await
60 .map_err(|e| GenerateError::BackendError(e.to_string()))?;
61
62 let text = response
63 .text()
64 .await
65 .map_err(|e| GenerateError::BackendError(e.to_string()))?;
66
67 Ok(text
68 .lines()
69 .map(|line| {
70 let json: serde_json::Value = serde_json::from_str(line).unwrap();
72 json["response"].as_str().unwrap_or_default().to_string()
73 })
74 .collect::<String>()
75 .trim()
76 .to_string())
77 }
78}