intelli_shell/ai/
mod.rs

1use std::{env, fmt::Debug, time::Duration};
2
3use color_eyre::eyre::{Context, ContextCompat, eyre};
4use reqwest::{
5    Client, ClientBuilder, RequestBuilder, Response, StatusCode,
6    header::{self, HeaderMap, HeaderName, HeaderValue},
7};
8use schemars::{JsonSchema, Schema, schema_for};
9use serde::{Deserialize, de::DeserializeOwned};
10use tokio_util::sync::CancellationToken;
11use tracing::instrument;
12
13use crate::{
14    config::AiModelConfig,
15    errors::{AppError, Result, UserFacingError},
16};
17
18mod anthropic;
19mod gemini;
20mod ollama;
21mod openai;
22
/// A trait that defines the provider-specific logic for the generic [`AiClient`]
pub trait AiProviderBase: Send + Sync {
    /// The name of the provider
    fn provider_name(&self) -> &'static str;

    /// Returns the header name and value to authenticate the given api key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// The name of the environment variable expected to have the api key
    fn api_key_env_var_name(&self) -> &str;

    /// Build the provider-specific request
    ///
    /// Implementations embed the prompts and the json schema into the request body;
    /// authentication headers are added afterwards by [`AiClient`]
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}
43
/// Extension of [`AiProviderBase`] with the async, provider-specific response parsing
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parse the provider-specific response body into the expected structured type
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
53
/// A generic client to communicate with AI providers
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    // Shared reqwest client, pre-configured with timeouts and default headers
    inner: Client,
    // Alias of the primary model, used in log messages
    primary_alias: &'a str,
    // Configuration of the primary model, always attempted first
    primary: &'a AiModelConfig,
    // Alias of the fallback model, used in log messages
    fallback_alias: &'a str,
    // Optional fallback model, retried when the primary is rate-limited or unavailable
    fallback: Option<&'a AiModelConfig>,
}
63impl<'a> AiClient<'a> {
64    /// Creates a new AI client with a primary and an optional fallback model configuration
65    pub fn new(
66        primary_alias: &'a str,
67        primary: &'a AiModelConfig,
68        fallback_alias: &'a str,
69        fallback: Option<&'a AiModelConfig>,
70    ) -> Result<Self> {
71        // Construct the base headers for all requests
72        let mut headers = HeaderMap::new();
73        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));
74
75        // Build the reqwest client
76        let inner = ClientBuilder::new()
77            .connect_timeout(Duration::from_secs(5))
78            .timeout(Duration::from_secs(5 * 60))
79            .user_agent("intelli-shell")
80            .default_headers(headers)
81            .build()
82            .wrap_err("Couldn't build AI client")?;
83
84        Ok(AiClient {
85            inner,
86            primary_alias,
87            primary,
88            fallback_alias,
89            fallback,
90        })
91    }
92
93    /// Generate some command suggestions based on the given prompt
94    #[instrument(skip_all)]
95    pub async fn generate_command_suggestions(
96        &self,
97        sys_prompt: &str,
98        user_prompt: &str,
99        cancellation_token: CancellationToken,
100    ) -> Result<CommandSuggestions> {
101        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
102    }
103
104    /// Generate a command fix based on the given prompt
105    #[instrument(skip_all)]
106    pub async fn generate_command_fix(
107        &self,
108        sys_prompt: &str,
109        user_prompt: &str,
110        cancellation_token: CancellationToken,
111    ) -> Result<CommandFix> {
112        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
113    }
114
115    /// Generate a command for a dynamic variable completion
116    #[instrument(skip_all)]
117    pub async fn generate_completion_suggestion(
118        &self,
119        sys_prompt: &str,
120        user_prompt: &str,
121        cancellation_token: CancellationToken,
122    ) -> Result<VariableCompletionSuggestion> {
123        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
124    }
125
126    /// The inner logic to generate content from a prompt with an AI provider.
127    ///
128    /// It attempts the primary model first, and uses the fallback model if the primary is rate-limited.
129    async fn generate_content<T>(
130        &self,
131        sys_prompt: &str,
132        user_prompt: &str,
133        cancellation_token: CancellationToken,
134    ) -> Result<T>
135    where
136        T: DeserializeOwned + JsonSchema + Debug,
137    {
138        // First, try with the primary model
139        let primary_result = self
140            .execute_request(self.primary, sys_prompt, user_prompt, cancellation_token.clone())
141            .await;
142
143        // Check if the primary attempt failed with a rate limit error
144        if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
145            // If it's a rate limit error and we have a fallback model, try again with it
146            if let Some(fallback) = self.fallback {
147                tracing::warn!(
148                    "Primary model ({}) rate-limited, retrying with fallback ({})",
149                    self.primary_alias,
150                    self.fallback_alias
151                );
152                return self
153                    .execute_request(fallback, sys_prompt, user_prompt, cancellation_token)
154                    .await;
155            }
156        }
157
158        // Check if the primary attempt failed with a service unavailable error
159        if let Err(AppError::UserFacing(UserFacingError::AiUnavailable)) = &primary_result {
160            // Some APIs respond this status when a specific model is overloaded, so we try with the fallback
161            if let Some(fallback) = self.fallback {
162                tracing::warn!(
163                    "Primary model ({}) unavailable, retrying with fallback ({})",
164                    self.primary_alias,
165                    self.fallback_alias
166                );
167                return self
168                    .execute_request(fallback, sys_prompt, user_prompt, cancellation_token)
169                    .await;
170            }
171        }
172
173        // Otherwise, return the result of the primary attempt
174        primary_result
175    }
176
177    /// Executes a single AI content generation request against a specific model configuration
178    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
179    async fn execute_request<T>(
180        &self,
181        config: &AiModelConfig,
182        sys_prompt: &str,
183        user_prompt: &str,
184        cancellation_token: CancellationToken,
185    ) -> Result<T>
186    where
187        T: DeserializeOwned + JsonSchema + Debug,
188    {
189        let provider = config.provider();
190
191        // Generate the json schema from the expected type
192        let json_schema = build_json_schema_for::<T>()?;
193
194        // Prepare the request body
195        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);
196
197        // Include auth header for this config
198        let api_key_env = provider.api_key_env_var_name();
199        if let Ok(api_key) = env::var(api_key_env) {
200            let (header_name, header_value) = provider.auth_header(api_key);
201            let mut header_value =
202                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
203            header_value.set_sensitive(true);
204            req_builder = req_builder.header(header_name, header_value);
205        }
206
207        // Build the request
208        let req = req_builder.build().wrap_err("Couldn't build api request")?;
209
210        // Call the API
211        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
212        let res = tokio::select! {
213            biased;
214            _ = cancellation_token.cancelled() => {
215                return Err(UserFacingError::Cancelled.into());
216            }
217            res = self.inner.execute(req) => {
218                res.map_err(|err| {
219                    if err.is_timeout() {
220                        tracing::error!("Request timeout: {err:?}");
221                        UserFacingError::AiRequestTimeout
222                    } else if err.is_connect() {
223                        tracing::error!("Couldn't connect to the API: {err:?}");
224                        UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
225                    } else {
226                        tracing::error!("Couldn't perform the request: {err:?}");
227                        UserFacingError::AiRequestFailed(err.to_string())
228                    }
229                })?
230            }
231        };
232
233        // Check the response status
234        if !res.status().is_success() {
235            let status = res.status();
236            let status_str = status.as_str();
237            let body = res.text().await.unwrap_or_default();
238            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
239                tracing::warn!(
240                    "Got response [{status_str}] {}",
241                    status.canonical_reason().unwrap_or_default()
242                );
243                tracing::debug!("{body}");
244                return Err(
245                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
246                );
247            } else if status == StatusCode::TOO_MANY_REQUESTS {
248                tracing::info!("Got response [{status_str}] Too Many Requests");
249                tracing::debug!("{body}");
250                return Err(UserFacingError::AiRateLimit.into());
251            } else if status == StatusCode::SERVICE_UNAVAILABLE {
252                tracing::info!("Got response [{status_str}] Service Unavailable");
253                tracing::debug!("{body}");
254                return Err(UserFacingError::AiUnavailable.into());
255            } else if status == StatusCode::BAD_REQUEST {
256                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
257                return Err(eyre!("Bad request while fetching {} API:\n{body}", provider.provider_name()).into());
258            } else if let Some(reason) = status.canonical_reason() {
259                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
260                return Err(
261                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
262                );
263            } else {
264                tracing::error!("Got response [{status_str}]:\n{body}");
265                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
266            }
267        }
268
269        // Parse successful response
270        match &config {
271            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
272            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
273            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
274            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
275        }
276    }
277}
278
// Top-level response type for command suggestion generation.
// NOTE(review): the `///` docs on these schema types presumably end up as `description`
// entries in the derived json schema sent to the model (schemars) — confirm before rewording
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    /// The list of suggested commands for the user to choose from
    pub suggestions: Vec<CommandSuggestion>,
}
284
/// A structured object representing a suggestion for a shell command and its explanation
// NOTE(review): the `///` docs below are presumably serialized into the json schema as
// `description` fields (schemars), i.e. they are instructions for the model — confirm
// before rewording
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// A clear, concise, human-readable explanation of the generated command, usually a single sentence.
    /// This description is for the end-user to help them understand the command before executing it.
    pub description: String,
    /// The command template string. Use `{{variable-name}}` syntax for any placeholders that require user input.
    /// For ephemeral values like commit messages or sensitive values like API keys or passwords, use the triple-brace
    /// syntax `{{{variable-name}}}`.
    pub command: String,
}
296
/// A structured object to propose a fix to a failed command
// NOTE(review): the `///` docs below are presumably serialized into the json schema as
// `description` fields (schemars), i.e. they are instructions for the model — confirm
// before rewording
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// A very brief, 2-5 word summary of the error category.
    /// Examples: "Command Not Found", "Permission Denied", "Invalid Argument", "Git Typo".
    pub summary: String,
    /// A detailed, human-readable explanation of the root cause of the error.
    /// This section should explain *what* went wrong and *why*, based on the provided error message,
    /// but should not contain the solution itself.
    pub diagnosis: String,
    /// A human-readable string describing the recommended next steps.
    /// This can be a description of a fix, diagnostic command(s) to run, or a suggested workaround.
    pub proposal: String,
    /// The corrected, valid, ready-to-execute command string if the error was a simple typo or syntax issue.
    /// This field should only be populated if a direct command correction is the primary solution.
    /// Example: "git status"
    pub fixed_command: String,
}
315
/// A structured object to propose a command for a dynamic variable completion
// NOTE(review): the `///` doc below is presumably serialized into the json schema as a
// `description` field (schemars), i.e. it instructs the model — confirm before rewording
#[derive(Debug, Deserialize, JsonSchema)]
pub struct VariableCompletionSuggestion {
    /// The shell command that generates the suggestion values when executed
    pub command: String,
}
322
323/// Build the json schema for the given type, including `additionalProperties: false`
324fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
325    // Generate the derived schema
326    let mut schema = schema_for!(T);
327
328    // The schema must be an object, for most LLMs to support it
329    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
330    root.insert("additionalProperties".into(), false.into());
331
332    // If there's any additional object definition, also update the additionalProperties
333    if let Some(defs) = root.get_mut("$defs") {
334        for definition in defs.as_object_mut().wrap_err("Expected objects at $defs")?.values_mut() {
335            if let Some(def_obj) = definition.as_object_mut()
336                && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
337            {
338                def_obj.insert("additionalProperties".into(), false.into());
339            }
340        }
341    }
342
343    Ok(schema)
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// The schema root must be an object with `additionalProperties: false`,
    /// which is the invariant `build_json_schema_for` exists to enforce
    #[test]
    fn test_command_suggestions_schema() {
        let schema = build_json_schema_for::<CommandSuggestions>().unwrap();
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["type"], "object");
        assert_eq!(json["additionalProperties"], serde_json::Value::Bool(false));
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }

    /// Same invariant for the command fix schema
    #[test]
    fn test_command_fix_schema() {
        let schema = build_json_schema_for::<CommandFix>().unwrap();
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["type"], "object");
        assert_eq!(json["additionalProperties"], serde_json::Value::Bool(false));
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }
}