// intelli_shell/ai/mod.rs

1use std::{env, fmt::Debug, time::Duration};
2
3use color_eyre::eyre::{Context, ContextCompat, eyre};
4use reqwest::{
5    Client, ClientBuilder, RequestBuilder, Response, StatusCode,
6    header::{self, HeaderMap, HeaderName, HeaderValue},
7};
8use schemars::{JsonSchema, Schema, schema_for};
9use serde::{Deserialize, de::DeserializeOwned};
10use tracing::instrument;
11
12use crate::{
13    config::AiModelConfig,
14    errors::{AppError, Result, UserFacingError},
15};
16
17mod anthropic;
18mod gemini;
19mod ollama;
20mod openai;
21
/// A trait that defines the provider-specific logic for the generic [`AiClient`]
pub trait AiProviderBase: Send + Sync {
    /// The name of the provider
    fn provider_name(&self) -> &'static str;

    /// Returns the header name and value to authenticate the given api key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// The name of the environment variable expected to have the api key
    fn api_key_env_var_name(&self) -> &str;

    /// Build the provider-specific request
    ///
    /// # Arguments
    /// * `client` - The shared HTTP client the request is built from
    /// * `sys_prompt` - The system prompt describing the task for the model
    /// * `user_prompt` - The user-specific content of the request
    /// * `json_schema` - Schema the model's structured output must conform to
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}
42
/// A trait that defines the provider-specific response parsing for the generic [`AiClient`]
// NOTE: `trait_variant::make(Send)` generates a variant of this trait whose async methods
// return `Send` futures, so implementations can be awaited from multi-threaded executors
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parse the provider-specific response
    ///
    /// `Self: Sized` plus the generic `T` make this method non-object-safe, which is why
    /// callers dispatch over the concrete config variants instead of a trait object
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
52
/// A generic client to communicate with AI providers
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    // The underlying HTTP client, shared by all requests
    inner: Client,
    // Alias (display name) of the primary model, used in log messages
    primary_alias: &'a str,
    // Configuration of the primary model, always tried first
    primary: &'a AiModelConfig,
    // Alias (display name) of the fallback model, used in log messages
    fallback_alias: &'a str,
    // Optional fallback model, used only when the primary is rate-limited
    fallback: Option<&'a AiModelConfig>,
}
impl<'a> AiClient<'a> {
    /// Creates a new AI client with a primary and an optional fallback model configuration
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client can't be built
    pub fn new(
        primary_alias: &'a str,
        primary: &'a AiModelConfig,
        fallback_alias: &'a str,
        fallback: Option<&'a AiModelConfig>,
    ) -> Result<Self> {
        // Construct the base headers for all requests
        let mut headers = HeaderMap::new();
        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));

        // Build the reqwest client: fail fast when connecting (2s) but allow a long overall
        // timeout (5 min), since LLM generation responses can be slow
        let inner = ClientBuilder::new()
            .connect_timeout(Duration::from_secs(2))
            .timeout(Duration::from_secs(5 * 60))
            .user_agent("intelli-shell")
            .default_headers(headers)
            .build()
            .wrap_err("Couldn't build AI client")?;

        Ok(AiClient {
            inner,
            primary_alias,
            primary,
            fallback_alias,
            fallback,
        })
    }

    /// Generate some command suggestions based on the given prompt
    #[instrument(skip_all)]
    pub async fn generate_command_suggestions(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<CommandSuggestions> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generate a command fix based on the given prompt
    #[instrument(skip_all)]
    pub async fn generate_command_fix(&self, sys_prompt: &str, user_prompt: &str) -> Result<CommandFix> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// The inner logic to generate content from a prompt with an AI provider.
    ///
    /// It attempts the primary model first, and uses the fallback model if the primary is rate-limited.
    /// Any other error from the primary (timeouts, auth failures, bad requests) is returned
    /// as-is without consulting the fallback.
    async fn generate_content<T>(&self, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        // First, try with the primary model
        let primary_result = self.execute_request(self.primary, sys_prompt, user_prompt).await;

        // Check if the primary attempt failed with a rate limit error
        if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
            // If it's a rate limit error and we have a fallback model, try again with it
            // (a rate limit from the fallback itself is returned to the caller)
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) rate-limited, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        // Otherwise, return the result of the primary attempt
        primary_result
    }

    /// Executes a single AI content generation request against a specific model configuration
    ///
    /// Transport failures and non-success HTTP statuses are mapped to user-facing errors; in
    /// particular a 429 becomes [`UserFacingError::AiRateLimit`], which [`Self::generate_content`]
    /// uses to decide whether to retry with the fallback model
    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
    async fn execute_request<T>(&self, config: &AiModelConfig, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let provider = config.provider();

        // Generate the json schema from the expected type
        let json_schema = build_json_schema_for::<T>()?;

        // Prepare the request body
        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);

        // Include auth header for this config; if the env var is unset the header is simply
        // omitted (presumably some providers don't require a key — a missing-key failure
        // then surfaces later as a 401/403 below)
        let api_key_env = provider.api_key_env_var_name();
        if let Ok(api_key) = env::var(api_key_env) {
            let (header_name, header_value) = provider.auth_header(api_key);
            let mut header_value =
                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
            // Mark the value as sensitive so it's redacted from Debug output of the header
            header_value.set_sensitive(true);
            req_builder = req_builder.header(header_name, header_value);
        }

        // Build the request
        let req = req_builder.build().wrap_err("Couldn't build api request")?;

        // Call the API, translating transport-level failures into user-facing errors
        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
        let res = self.inner.execute(req).await.map_err(|err| {
            if err.is_timeout() {
                tracing::error!("Request timeout: {err:?}");
                UserFacingError::AiRequestTimeout
            } else if err.is_connect() {
                tracing::error!("Couldn't connect to the API: {err:?}");
                UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
            } else {
                tracing::error!("Couldn't perform the request: {err:?}");
                UserFacingError::AiRequestFailed(err.to_string())
            }
        })?;

        // Check the response status, mapping well-known failure codes to specific errors
        if !res.status().is_success() {
            let status = res.status();
            let status_str = status.as_str();
            // Body is consumed for logging/diagnostics; an unreadable body is treated as empty
            let body = res.text().await.unwrap_or_default();
            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
                // 401/403: the api key is missing or rejected by the provider
                tracing::warn!(
                    "Got response [{status_str}] {}",
                    status.canonical_reason().unwrap_or_default()
                );
                tracing::debug!("{body}");
                return Err(
                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
                );
            } else if status == StatusCode::TOO_MANY_REQUESTS {
                // 429: signals the caller to retry with the fallback model
                tracing::info!("Got response [{status_str}] Too Many Requests");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiRateLimit.into());
            } else if status == StatusCode::BAD_REQUEST {
                // 400: likely a bug in the request we built, so it's an internal error
                // (including the body) rather than a user-facing one
                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
                return Err(eyre!("Bad request while fetching {} API:\n{body}", provider.provider_name()).into());
            } else if let Some(reason) = status.canonical_reason() {
                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
                return Err(
                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
                );
            } else {
                // Non-standard status code with no canonical reason phrase
                tracing::error!("Got response [{status_str}]:\n{body}");
                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
            }
        }

        // Parse successful response. `parse_response` is generic (and `Self: Sized`), so it
        // can't be called through a trait object — dispatch explicitly over the variants
        match &config {
            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
        }
    }
}
218
// Response payload for `generate_command_suggestions`.
// NOTE(review): schemars derives typically emit `///` doc comments as `description` fields in
// the generated JSON schema, which is sent to the AI provider — so the doc comments below
// double as model instructions; confirm before rewording them
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    /// The list of suggested commands for the user to choose from
    pub suggestions: Vec<CommandSuggestion>,
}
224
/// A structured object representing a suggestion for a shell command and its explanation
// NOTE(review): the `///` doc comments here are likely serialized into the JSON schema sent
// to the provider (schemars `description`s), so they are written for the model as much as
// for the reader; confirm before rewording them
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// A clear, concise, human-readable explanation of the generated command, usually a single sentence.
    /// This description is for the end-user to help them understand the command before executing it.
    pub description: String,
    /// The command template string. Use `{{variable-name}}` syntax for any placeholders that require user input.
    /// For ephemeral values like commit messages or sensitive values like API keys or passwords, use the triple-brace
    /// syntax `{{{variable-name}}}`.
    pub command: String,
}
236
/// A structured object to propose a fix to a failed command
// NOTE(review): the `///` doc comments here are likely serialized into the JSON schema sent
// to the provider (schemars `description`s), so they are written for the model as much as
// for the reader; confirm before rewording them
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// A very brief, 2-5 word summary of the error category.
    /// Examples: "Command Not Found", "Permission Denied", "Invalid Argument", "Git Typo".
    pub summary: String,
    /// A detailed, human-readable explanation of the root cause of the error.
    /// This section should explain *what* went wrong and *why*, based on the provided error message,
    /// but should not contain the solution itself.
    pub diagnosis: String,
    /// A human-readable string describing the recommended next steps.
    /// This can be a description of a fix, diagnostic command(s) to run, or a suggested workaround.
    pub proposal: String,
    /// The corrected, valid, ready-to-execute command string if the error was a simple typo or syntax issue.
    /// This field should only be populated if a direct command correction is the primary solution.
    /// Example: "git status"
    pub fixed_command: String,
}
255
256/// Build the json schema for the given type, including `additionalProperties: false`
257fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
258    // Generate the derived schema
259    let mut schema = schema_for!(T);
260
261    // The schema must be an object, for most LLMs to support it
262    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
263    root.insert("additionalProperties".into(), false.into());
264
265    // If there's any additional object definition, also update the additionalProperties
266    if let Some(defs) = root.get_mut("$defs") {
267        for definition in defs.as_object_mut().wrap_err("Expected objects at $defs")?.values_mut() {
268            if let Some(def_obj) = definition.as_object_mut()
269                && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
270            {
271                def_obj.insert("additionalProperties".into(), false.into());
272            }
273        }
274    }
275
276    Ok(schema)
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    // Shared driver: builds the schema for `T` and prints it for manual inspection,
    // panicking if generation or serialization fails
    fn print_schema_of<T: JsonSchema>() {
        let schema = build_json_schema_for::<T>().unwrap();
        let pretty = serde_json::to_string_pretty(&schema).unwrap();
        println!("{pretty}");
    }

    #[test]
    fn test_command_suggestions_schema() {
        print_schema_of::<CommandSuggestions>();
    }

    #[test]
    fn test_command_fix_schema() {
        print_schema_of::<CommandFix>();
    }
}