// intelli_shell/ai/ollama.rs

1use std::fmt::Debug;
2
3use color_eyre::eyre::Context;
4use reqwest::{
5    Client, RequestBuilder, Response,
6    header::{self, HeaderName},
7};
8use schemars::{JsonSchema, Schema};
9use serde::{Deserialize, de::DeserializeOwned};
10use serde_json::{Value as Json, json};
11
12use super::{AiProvider, AiProviderBase};
13use crate::{
14    config::OllamaModelConfig,
15    errors::{Result, UserFacingError},
16};
17
18impl AiProviderBase for OllamaModelConfig {
19    fn provider_name(&self) -> &'static str {
20        "Ollama"
21    }
22
23    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
24        (header::AUTHORIZATION, format!("Bearer {api_key}"))
25    }
26
27    fn api_key_env_var_name(&self) -> &str {
28        &self.api_key_env
29    }
30
31    fn build_request(
32        &self,
33        client: &Client,
34        sys_prompt: &str,
35        user_prompt: &str,
36        json_schema: &Schema,
37    ) -> RequestBuilder {
38        // Request body
39        // https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion
40        let request_body = json!({
41            "model": self.model,
42            "messages": [
43                {
44                    "role": "system",
45                    "content": sys_prompt
46                },
47                {
48                    "role": "user",
49                    "content": user_prompt
50                }
51            ],
52            "format": json_schema,
53            "stream": false
54        });
55
56        tracing::trace!("Request:\n{request_body:#}");
57
58        // Chat url
59        let url = format!("{}/api/chat", self.url);
60
61        // Request
62        client.post(url).json(&request_body)
63    }
64}
65
66impl AiProvider for OllamaModelConfig {
67    async fn parse_response<T>(&self, res: Response) -> Result<T>
68    where
69        T: DeserializeOwned + JsonSchema + Debug,
70    {
71        // Parse successful response
72        let res: Json = res.json().await.wrap_err("Ollama response not a json")?;
73        tracing::trace!("Response:\n{res:#}");
74        let res: OllamaResponse = serde_json::from_value(res).wrap_err("Couldn't parse Ollama response")?;
75
76        // Validate the response content
77        let Some(message) = res.message.content.filter(|c| !c.trim().is_empty()) else {
78            tracing::error!("Ollama returned an empty response");
79            return Err(UserFacingError::AiRequestFailed(String::from("received an empty response")).into());
80        };
81
82        // Parse the message
83        Ok(serde_json::from_str(&message).map_err(|err| {
84            tracing::error!("Couldn't parse API response into the expected format: {err}\nMessage:\n{message}");
85            UserFacingError::AiRequestFailed(String::from("couldn't parse api response into the expected format"))
86        })?)
87    }
88}
89
90#[derive(Debug, Deserialize)]
91struct OllamaResponse {
92    message: OllamaResponseMessage,
93}
94
95#[derive(Debug, Deserialize)]
96struct OllamaResponseMessage {
97    #[serde(default)]
98    content: Option<String>,
99}
100
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    /// End-to-end smoke test against a locally running Ollama server.
    ///
    /// Ignored by default. It expects a server at `http://localhost:11434` with the `gemma3:1b`
    /// model available, and can be run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore] // Real API calls require a running ollama server
    async fn test_ollama_api() -> Result<()> {
        // `try_init` instead of `init`: another test in the same binary may already have installed
        // a global subscriber, in which case `init` would panic. Failing to install is harmless here.
        let _ = tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .try_init();
        let config = AiModelConfig::Ollama(OllamaModelConfig {
            model: "gemma3:1b".into(),
            url: "http://localhost:11434".into(),
            api_key_env: "OLLAMA_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        let res = client
            .generate_command_suggestions(
                "you're a cli expert, that will provide command suggestions based on what the user wants to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!("  # {}", command.description);
            tracing::info!("  {}", command.command);
        }
        Ok(())
    }
}