intelli_shell/ai/ollama.rs

use std::fmt::Debug;

use color_eyre::eyre::Context;
use reqwest::{
    Client, RequestBuilder, Response,
    header::{self, HeaderName},
};
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, de::DeserializeOwned};
use serde_json::{Value as Json, json};

use super::{AiProvider, AiProviderBase};
use crate::{
    config::OllamaModelConfig,
    errors::{Result, UserFacingError},
};
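// Base wiring for the Ollama provider: name, API key env var lookup, auth header, and request building.
// A stock local Ollama server does not require an API key; the bearer header is presumably only
// relevant when the instance sits behind an authenticating proxy or a hosted endpoint.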
impl AiProviderBase for OllamaModelConfig {
    fn provider_name(&self) -> &'static str {
        "Ollama"
    }

    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
        (header::AUTHORIZATION, format!("Bearer {api_key}"))
    }

    fn api_key_env_var_name(&self) -> &str {
        &self.api_key_env
    }

    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder {
        // Non-streaming chat request; the `format` field asks Ollama to constrain the
        // model output to the provided JSON schema (structured outputs)
        let request_body = json!({
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": sys_prompt
                },
                {
                    "role": "user",
                    "content": user_prompt
                }
            ],
            "format": json_schema,
            "stream": false
        });

        tracing::trace!("Request:\n{request_body:#}");

        let url = format!("{}/api/chat", self.url);

        client.post(url).json(&request_body)
    }
}
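// Response handling: extract the assistant message content from the /api/chat reply and
// deserialize it into the caller's expected type `T`.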
impl AiProvider for OllamaModelConfig {
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        // Parse into a generic JSON value first so the raw payload can be traced
        let res: Json = res.json().await.wrap_err("Ollama response is not valid json")?;
        tracing::trace!("Response:\n{res:#}");
        let res: OllamaResponse = serde_json::from_value(res).wrap_err("Couldn't parse Ollama response")?;

        let Some(message) = res.message.content.filter(|c| !c.trim().is_empty()) else {
            tracing::error!("Ollama returned an empty response");
            return Err(UserFacingError::AiRequestFailed(String::from("received an empty response")).into());
        };

        // The message content is itself JSON, constrained by the schema sent in the request
        Ok(serde_json::from_str(&message).map_err(|err| {
            tracing::error!("Couldn't parse API response into the expected format: {err}\nMessage:\n{message}");
            UserFacingError::AiRequestFailed(String::from("couldn't parse api response into the expected format"))
        })?)
    }
}
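/// Minimal view of Ollama's `/api/chat` response; only the assistant message content is
/// needed, so any other fields are ignored during deserialization.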
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    message: OllamaResponseMessage,
}

#[derive(Debug, Deserialize)]
struct OllamaResponseMessage {
    #[serde(default)]
    content: Option<String>,
}
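// These tests hit a live Ollama server and are marked #[ignore], so they are skipped in normal
// test runs. Run them with `cargo test -- --ignored` against a local instance listening on
// http://localhost:11434, with the model pulled first (e.g. `ollama pull gemma3:1b`).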
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    #[tokio::test]
    #[ignore]
    async fn test_ollama_api() -> Result<()> {
        tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .init();
        let config = AiModelConfig::Ollama(OllamaModelConfig {
            model: "gemma3:1b".into(),
            url: "http://localhost:11434".into(),
            api_key_env: "OLLAMA_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        let res = client
            .generate_command_suggestions(
                "you're a cli expert that will provide command suggestions based on what the user wants to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!(" # {}", command.description);
            tracing::info!(" {}", command.command);
        }
        Ok(())
    }
}