use std::{env, fmt::Debug, time::Duration};

use color_eyre::eyre::{Context, ContextCompat, eyre};
use reqwest::{
    Client, ClientBuilder, RequestBuilder, Response, StatusCode,
    header::{self, HeaderMap, HeaderName, HeaderValue},
};
use schemars::{JsonSchema, Schema, schema_for};
use serde::{Deserialize, de::DeserializeOwned};
use tracing::instrument;

use crate::{
    config::AiModelConfig,
    errors::{AppError, Result, UserFacingError},
};

mod anthropic;
mod gemini;
mod ollama;
mod openai;

/// Base functionality shared by every AI provider integration
pub trait AiProviderBase: Send + Sync {
    /// Name of the provider, used for logging and error messages
    fn provider_name(&self) -> &'static str;

    /// Header name and value used to authenticate requests with the given API key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// Name of the environment variable that holds the API key for this provider
    fn api_key_env_var_name(&self) -> &str;

    /// Builds the provider-specific request for the given prompts and output schema
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}

#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parses a successful API response into the expected output type
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
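
// A minimal sketch of a provider implementing `AiProviderBase`, kept test-only so it never
// ships; the real implementations live in the `anthropic`, `gemini`, `ollama` and `openai`
// submodules. The endpoint URL, header scheme, env var name and body shape below are
// illustrative assumptions, not taken from any of those modules.
#[cfg(test)]
#[allow(dead_code)]
struct ExampleProvider;

#[cfg(test)]
impl AiProviderBase for ExampleProvider {
    fn provider_name(&self) -> &'static str {
        "example"
    }

    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
        // Hypothetical bearer-token scheme
        (header::AUTHORIZATION, format!("Bearer {api_key}"))
    }

    fn api_key_env_var_name(&self) -> &str {
        "EXAMPLE_API_KEY"
    }

    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder {
        // Hypothetical request body; each real provider uses its own API format
        let body = serde_json::json!({
            "system": sys_prompt,
            "prompt": user_prompt,
            "schema": json_schema
        });
        client.post("https://api.example.invalid/v1/generate").body(body.to_string())
    }
}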

/// Client to interact with the configured AI models, falling back to a secondary model when needed
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    inner: Client,
    primary_alias: &'a str,
    primary: &'a AiModelConfig,
    fallback_alias: &'a str,
    fallback: Option<&'a AiModelConfig>,
}

impl<'a> AiClient<'a> {
    /// Creates a new client for the given primary model and optional fallback
    pub fn new(
        primary_alias: &'a str,
        primary: &'a AiModelConfig,
        fallback_alias: &'a str,
        fallback: Option<&'a AiModelConfig>,
    ) -> Result<Self> {
        let mut headers = HeaderMap::new();
        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let inner = ClientBuilder::new()
            .connect_timeout(Duration::from_secs(5))
            .timeout(Duration::from_secs(5 * 60))
            .user_agent("intelli-shell")
            .default_headers(headers)
            .build()
            .wrap_err("Couldn't build AI client")?;

        Ok(AiClient {
            inner,
            primary_alias,
            primary,
            fallback_alias,
            fallback,
        })
    }

    /// Generates command suggestions for the given prompts
    #[instrument(skip_all)]
    pub async fn generate_command_suggestions(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<CommandSuggestions> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generates a fix for a failing command
    #[instrument(skip_all)]
    pub async fn generate_command_fix(&self, sys_prompt: &str, user_prompt: &str) -> Result<CommandFix> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generates a completion suggestion for a command variable
    #[instrument(skip_all)]
    pub async fn generate_completion_suggestion(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<VariableCompletionSuggestion> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generates content from the AI, retrying with the fallback model if the primary one
    /// is rate-limited or unavailable
    async fn generate_content<T>(&self, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let primary_result = self.execute_request(self.primary, sys_prompt, user_prompt).await;

        if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) rate-limited, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        if let Err(AppError::UserFacing(UserFacingError::AiUnavailable)) = &primary_result {
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) unavailable, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        primary_result
    }

    /// Executes a request against the given model config, mapping non-successful responses into errors
    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
    async fn execute_request<T>(&self, config: &AiModelConfig, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let provider = config.provider();

        // Build the JSON schema the response must conform to
        let json_schema = build_json_schema_for::<T>()?;

        // Let the provider build its API-specific request
        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);

        // Attach the auth header if the API key env var is set
        let api_key_env = provider.api_key_env_var_name();
        if let Ok(api_key) = env::var(api_key_env) {
            let (header_name, header_value) = provider.auth_header(api_key);
            let mut header_value =
                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
            header_value.set_sensitive(true);
            req_builder = req_builder.header(header_name, header_value);
        }

        let req = req_builder.build().wrap_err("Couldn't build API request")?;

        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
        let res = self.inner.execute(req).await.map_err(|err| {
            if err.is_timeout() {
                tracing::error!("Request timeout: {err:?}");
                UserFacingError::AiRequestTimeout
            } else if err.is_connect() {
                tracing::error!("Couldn't connect to the API: {err:?}");
                UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
            } else {
                tracing::error!("Couldn't perform the request: {err:?}");
                UserFacingError::AiRequestFailed(err.to_string())
            }
        })?;

        // Map non-successful responses to user-facing errors
        if !res.status().is_success() {
            let status = res.status();
            let status_str = status.as_str();
            let body = res.text().await.unwrap_or_default();
            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
                tracing::warn!(
                    "Got response [{status_str}] {}",
                    status.canonical_reason().unwrap_or_default()
                );
                tracing::debug!("{body}");
                return Err(
                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
                );
            } else if status == StatusCode::TOO_MANY_REQUESTS {
                tracing::info!("Got response [{status_str}] Too Many Requests");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiRateLimit.into());
            } else if status == StatusCode::SERVICE_UNAVAILABLE {
                tracing::info!("Got response [{status_str}] Service Unavailable");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiUnavailable.into());
            } else if status == StatusCode::BAD_REQUEST {
                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
                return Err(eyre!("Bad request while calling the {} API:\n{body}", provider.provider_name()).into());
            } else if let Some(reason) = status.canonical_reason() {
                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
                return Err(
                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
                );
            } else {
                tracing::error!("Got response [{status_str}]:\n{body}");
                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
            }
        }

        // Delegate response parsing to the provider-specific implementation
        match config {
            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
        }
    }
}

/// List of command suggestions for a user request
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    pub suggestions: Vec<CommandSuggestion>,
}

/// A single command suggestion
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// Brief description of what the command does
    pub description: String,
    /// The suggested command
    pub command: String,
}

/// A proposed fix for a failing command
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// Short summary of the problem
    pub summary: String,
    /// Explanation of why the command failed
    pub diagnosis: String,
    /// Proposed approach to fix it
    pub proposal: String,
    /// The corrected command
    pub fixed_command: String,
}

/// A suggested completion for a command variable
#[derive(Debug, Deserialize, JsonSchema)]
pub struct VariableCompletionSuggestion {
    pub command: String,
}

/// Builds the JSON schema for the given type, disallowing additional properties on the
/// root object and on any object definitions under `$defs`
fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
    let mut schema = schema_for!(T);

    // The root must be a closed object
    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
    root.insert("additionalProperties".into(), false.into());

    // Close any nested object definitions as well
    if let Some(defs) = root.get_mut("$defs") {
        for definition in defs.as_object_mut().wrap_err("Expected an object at $defs")?.values_mut() {
            if let Some(def_obj) = definition.as_object_mut()
                && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
            {
                def_obj.insert("additionalProperties".into(), false.into());
            }
        }
    }

    Ok(schema)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_suggestions_schema() {
        let schema = build_json_schema_for::<CommandSuggestions>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }

    #[test]
    fn test_command_fix_schema() {
        let schema = build_json_schema_for::<CommandFix>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }
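
    // Added sketch: asserts that `build_json_schema_for` closes the root object; it relies
    // only on items already used in this module (schemars + serde_json)
    #[test]
    fn test_variable_completion_suggestion_schema() {
        let schema = build_json_schema_for::<VariableCompletionSuggestion>().unwrap();
        let value = serde_json::to_value(&schema).unwrap();
        assert_eq!(value.get("additionalProperties"), Some(&serde_json::Value::Bool(false)));
    }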
}