intelli_shell/ai/mod.rs

use std::{env, fmt::Debug, time::Duration};

use color_eyre::eyre::{Context, ContextCompat, eyre};
use reqwest::{
    Client, ClientBuilder, RequestBuilder, Response, StatusCode,
    header::{self, HeaderMap, HeaderName, HeaderValue},
};
use schemars::{JsonSchema, Schema, schema_for};
use serde::{Deserialize, de::DeserializeOwned};
use tracing::instrument;

use crate::{
    config::AiModelConfig,
    errors::{AppError, Result, UserFacingError},
};

mod anthropic;
mod gemini;
mod ollama;
mod openai;

/// A trait that defines the provider-specific logic for the generic [`AiClient`]
pub trait AiProviderBase: Send + Sync {
    /// The name of the provider
    fn provider_name(&self) -> &'static str;

    /// Returns the header name and value used to authenticate requests with the given API key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// The name of the environment variable expected to hold the API key
    fn api_key_env_var_name(&self) -> &str;

    /// Build the provider-specific request
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}
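
// A minimal sketch of what a provider implementation looks like (the real ones live in
// the provider modules above; `ExampleProvider`, its `model` field, its endpoint and its
// body shape are hypothetical, not any specific provider's API):
//
//     impl AiProviderBase for ExampleProvider {
//         fn provider_name(&self) -> &'static str {
//             "example"
//         }
//
//         fn auth_header(&self, api_key: String) -> (HeaderName, String) {
//             (header::AUTHORIZATION, format!("Bearer {api_key}"))
//         }
//
//         fn api_key_env_var_name(&self) -> &str {
//             "EXAMPLE_API_KEY"
//         }
//
//         fn build_request(
//             &self,
//             client: &Client,
//             sys_prompt: &str,
//             user_prompt: &str,
//             json_schema: &Schema,
//         ) -> RequestBuilder {
//             client.post("https://api.example.com/v1/generate").json(&serde_json::json!({
//                 "model": self.model,
//                 "system": sys_prompt,
//                 "input": user_prompt,
//                 "response_schema": json_schema,
//             }))
//         }
//     }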

/// A trait that extends [`AiProviderBase`] with the async logic to parse provider responses
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parse the provider-specific response
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
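
// `parse_response` is typically a thin wrapper that pulls the model's JSON payload out of
// the provider's response envelope and deserializes it into `T`; a sketch, assuming the
// hypothetical `ExampleProvider` and wire format from above:
//
//     impl AiProvider for ExampleProvider {
//         async fn parse_response<T>(&self, res: Response) -> Result<T>
//         where
//             T: DeserializeOwned + JsonSchema + Debug,
//         {
//             #[derive(Deserialize)]
//             struct ApiResponse {
//                 output: String,
//             }
//             let api_res: ApiResponse = res.json().await.wrap_err("Couldn't parse api response")?;
//             Ok(serde_json::from_str(&api_res.output).wrap_err("Couldn't parse model output")?)
//         }
//     }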

/// A generic client to communicate with AI providers
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    inner: Client,
    primary_alias: &'a str,
    primary: &'a AiModelConfig,
    fallback_alias: &'a str,
    fallback: Option<&'a AiModelConfig>,
}
impl<'a> AiClient<'a> {
    /// Creates a new AI client with a primary and an optional fallback model configuration
    pub fn new(
        primary_alias: &'a str,
        primary: &'a AiModelConfig,
        fallback_alias: &'a str,
        fallback: Option<&'a AiModelConfig>,
    ) -> Result<Self> {
        // Construct the base headers for all requests
        let mut headers = HeaderMap::new();
        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));

        // Build the reqwest client
        let inner = ClientBuilder::new()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(5 * 60))
            .user_agent("intelli-shell")
            .default_headers(headers)
            .build()
            .wrap_err("Couldn't build AI client")?;

        Ok(AiClient {
            inner,
            primary_alias,
            primary,
            fallback_alias,
            fallback,
        })
    }
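
    // Usage sketch (aliases and configs are illustrative; `SYS_PROMPT` is a hypothetical
    // prompt constant):
    //
    //     let client = AiClient::new("gpt", &primary_config, "gemini", Some(&fallback_config))?;
    //     let suggestions = client
    //         .generate_command_suggestions(SYS_PROMPT, "list running docker containers")
    //         .await?;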

    /// Generate some command suggestions based on the given prompt
    #[instrument(skip_all)]
    pub async fn generate_command_suggestions(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<CommandSuggestions> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generate a command fix based on the given prompt
    #[instrument(skip_all)]
    pub async fn generate_command_fix(&self, sys_prompt: &str, user_prompt: &str) -> Result<CommandFix> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generate a command for a dynamic variable completion
    #[instrument(skip_all)]
    pub async fn generate_completion_suggestion(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<VariableCompletionSuggestion> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// The inner logic to generate content from a prompt with an AI provider.
    ///
    /// It attempts the primary model first and retries with the fallback model if the primary is
    /// rate-limited or unavailable.
    async fn generate_content<T>(&self, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        // First, try with the primary model
        let primary_result = self.execute_request(self.primary, sys_prompt, user_prompt).await;

        // Check if the primary attempt failed with a rate limit error
        if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
            // If it's a rate limit error and we have a fallback model, try again with it
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) rate-limited, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        // Check if the primary attempt failed with a service unavailable error
        if let Err(AppError::UserFacing(UserFacingError::AiUnavailable)) = &primary_result {
            // Some APIs respond with this status when a specific model is overloaded, so we try with the fallback
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) unavailable, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        // Otherwise, return the result of the primary attempt
        primary_result
    }

    /// Executes a single AI content generation request against a specific model configuration
    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
    async fn execute_request<T>(&self, config: &AiModelConfig, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let provider = config.provider();

        // Generate the json schema from the expected type
        let json_schema = build_json_schema_for::<T>()?;

        // Prepare the request body
        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);

        // Include auth header for this config
        let api_key_env = provider.api_key_env_var_name();
        if let Ok(api_key) = env::var(api_key_env) {
            let (header_name, header_value) = provider.auth_header(api_key);
            let mut header_value =
                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
            header_value.set_sensitive(true);
            req_builder = req_builder.header(header_name, header_value);
        }
        // Build the request
        let req = req_builder.build().wrap_err("Couldn't build API request")?;

        // Call the API
        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
        let res = self.inner.execute(req).await.map_err(|err| {
            if err.is_timeout() {
                tracing::error!("Request timeout: {err:?}");
                UserFacingError::AiRequestTimeout
            } else if err.is_connect() {
                tracing::error!("Couldn't connect to the API: {err:?}");
                UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
            } else {
                tracing::error!("Couldn't perform the request: {err:?}");
                UserFacingError::AiRequestFailed(err.to_string())
            }
        })?;

        // Check the response status
        if !res.status().is_success() {
            let status = res.status();
            let status_str = status.as_str();
            let body = res.text().await.unwrap_or_default();
            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
                tracing::warn!(
                    "Got response [{status_str}] {}",
                    status.canonical_reason().unwrap_or_default()
                );
                tracing::debug!("{body}");
                return Err(
                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
                );
            } else if status == StatusCode::TOO_MANY_REQUESTS {
                tracing::info!("Got response [{status_str}] Too Many Requests");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiRateLimit.into());
            } else if status == StatusCode::SERVICE_UNAVAILABLE {
                tracing::info!("Got response [{status_str}] Service Unavailable");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiUnavailable.into());
            } else if status == StatusCode::BAD_REQUEST {
                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
                return Err(eyre!("Bad request while fetching {} API:\n{body}", provider.provider_name()).into());
            } else if let Some(reason) = status.canonical_reason() {
                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
                return Err(
                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
                );
            } else {
                tracing::error!("Got response [{status_str}]:\n{body}");
                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
            }
        }

        // Parse successful response
        match &config {
            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
        }
    }
}

/// A structured object containing the list of suggested commands
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    /// The list of suggested commands for the user to choose from
    pub suggestions: Vec<CommandSuggestion>,
}

/// A structured object representing a suggestion for a shell command and its explanation
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// A clear, concise, human-readable explanation of the generated command, usually a single sentence.
    /// This description is for the end-user to help them understand the command before executing it.
    pub description: String,
    /// The command template string. Use `{{variable-name}}` syntax for any placeholders that require user input.
    /// For ephemeral values like commit messages or sensitive values like API keys or passwords, use the triple-brace
    /// syntax `{{{variable-name}}}`.
    pub command: String,
}
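
// A model response conforming to the schemas above could look like this (illustrative
// values only):
//
//     {
//       "suggestions": [
//         {
//           "description": "Recursively search a directory for a text pattern",
//           "command": "grep -r \"{{pattern}}\" {{directory}}"
//         }
//       ]
//     }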

/// A structured object to propose a fix to a failed command
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// A very brief, 2-5 word summary of the error category.
    /// Examples: "Command Not Found", "Permission Denied", "Invalid Argument", "Git Typo".
    pub summary: String,
    /// A detailed, human-readable explanation of the root cause of the error.
    /// This section should explain *what* went wrong and *why*, based on the provided error message,
    /// but should not contain the solution itself.
    pub diagnosis: String,
    /// A human-readable string describing the recommended next steps.
    /// This can be a description of a fix, diagnostic command(s) to run, or a suggested workaround.
    pub proposal: String,
    /// The corrected, valid, ready-to-execute command string if the error was a simple typo or syntax issue.
    /// This field should only be populated if a direct command correction is the primary solution.
    /// Example: "git status"
    pub fixed_command: String,
}
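
// For a simple `git stauts` typo, a conforming fix could look like (made-up values):
//
//     {
//       "summary": "Git Typo",
//       "diagnosis": "The command failed because `stauts` is not a git subcommand.",
//       "proposal": "Run `git status` instead.",
//       "fixed_command": "git status"
//     }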

/// A structured object to propose a command for a dynamic variable completion
#[derive(Debug, Deserialize, JsonSchema)]
pub struct VariableCompletionSuggestion {
    /// The shell command that generates the suggestion values when executed
    pub command: String,
}

/// Build the json schema for the given type, including `additionalProperties: false`
fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
    // Generate the derived schema
    let mut schema = schema_for!(T);

    // Most LLMs only accept structured output schemas whose root is an object
    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
    root.insert("additionalProperties".into(), false.into());

    // If there's any additional object definition, also update the additionalProperties
    if let Some(defs) = root.get_mut("$defs") {
        for definition in defs.as_object_mut().wrap_err("Expected objects at $defs")?.values_mut() {
            if let Some(def_obj) = definition.as_object_mut()
                && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
            {
                def_obj.insert("additionalProperties".into(), false.into());
            }
        }
    }

    Ok(schema)
}
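
// For `VariableCompletionSuggestion`, the generated schema looks roughly like the
// following (a sketch; the exact layout depends on the `schemars` version in use):
//
//     {
//       "$schema": "https://json-schema.org/draft/2020-12/schema",
//       "title": "VariableCompletionSuggestion",
//       "description": "A structured object to propose a command for a dynamic variable completion",
//       "type": "object",
//       "properties": {
//         "command": {
//           "description": "The shell command that generates the suggestion values when executed",
//           "type": "string"
//         }
//       },
//       "required": ["command"],
//       "additionalProperties": false
//     }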

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_suggestions_schema() {
        let schema = build_json_schema_for::<CommandSuggestions>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }

    #[test]
    fn test_command_fix_schema() {
        let schema = build_json_schema_for::<CommandFix>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }
}