//! intelli_shell/ai/openai.rs — OpenAI chat-completions provider

use std::fmt::Debug;

use color_eyre::eyre::Context;
use reqwest::{
    Client, RequestBuilder, Response,
    header::{self, HeaderName},
};
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, de::DeserializeOwned};
use serde_json::{Value as Json, json};

use super::{AiProvider, AiProviderBase};
use crate::{
    config::OpenAiModelConfig,
    errors::{Result, UserFacingError},
};
17
18impl AiProviderBase for OpenAiModelConfig {
19    fn provider_name(&self) -> &'static str {
20        "OpenAI"
21    }
22
23    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
24        (header::AUTHORIZATION, format!("Bearer {api_key}"))
25    }
26
27    fn api_key_env_var_name(&self) -> &str {
28        &self.api_key_env
29    }
30
31    fn build_request(
32        &self,
33        client: &Client,
34        sys_prompt: &str,
35        user_prompt: &str,
36        json_schema: &Schema,
37    ) -> RequestBuilder {
38        // Request body
39        // https://platform.openai.com/docs/api-reference/chat/create
40        let request_body = json!({
41            "model": self.model,
42            "messages": [
43                {
44                    "role": "system",
45                    "content": sys_prompt
46                },
47                {
48                    "role": "user",
49                    "content": user_prompt
50                }
51            ],
52            "response_format": {
53                "type": "json_schema",
54                "json_schema": {
55                    "name": "command_suggestions",
56                    "strict": true,
57                    "schema": json_schema
58                }
59            }
60        });
61
62        tracing::trace!("Request:\n{request_body:#}");
63
64        // Chat completions url
65        let url = format!("{}/chat/completions", self.url);
66
67        // Request
68        client.post(url).json(&request_body)
69    }
70}
71
72impl AiProvider for OpenAiModelConfig {
73    async fn parse_response<T>(&self, res: Response) -> Result<T>
74    where
75        T: DeserializeOwned + JsonSchema + Debug,
76    {
77        // Parse successful response
78        let res: Json = res.json().await.wrap_err("OpenAI response not a json")?;
79        tracing::trace!("Response:\n{res:#}");
80        let mut res: OpenAiResponse = serde_json::from_value(res).wrap_err("Couldn't parse OpenAI response")?;
81
82        // Validate the response content
83        if res.choices.is_empty() {
84            tracing::error!("Response got no choices: {res:?}");
85            return Err(UserFacingError::AiRequestFailed(String::from("received response with no choices")).into());
86        } else if res.choices.len() > 1 {
87            tracing::warn!("Response got {} choices", res.choices.len());
88        }
89
90        let choice = res.choices.remove(0);
91        if choice.finish_reason != "stop" {
92            tracing::error!("OpenAI response got an invalid finish reason: {}", choice.finish_reason);
93            return Err(UserFacingError::AiRequestFailed(format!(
94                "couldn't generate a valid response: {}",
95                choice.finish_reason
96            ))
97            .into());
98        }
99
100        if let Some(refusal) = choice.message.refusal
101            && !refusal.is_empty()
102        {
103            tracing::error!("OpenAI refused to answer: {refusal}");
104            return Err(UserFacingError::AiRequestFailed(format!("response refused: {refusal}")).into());
105        }
106
107        let Some(message) = choice.message.content.filter(|c| !c.trim().is_empty()) else {
108            tracing::error!("OpenAI returned an empty response");
109            return Err(UserFacingError::AiRequestFailed(String::from("received an empty response")).into());
110        };
111
112        // Parse the message
113        Ok(serde_json::from_str(&message).map_err(|err| {
114            tracing::error!("Couldn't parse API response into the expected format: {err}\nMessage:\n{message}");
115            UserFacingError::AiRequestFailed(String::from("couldn't parse api response into the expected format"))
116        })?)
117    }
118}
119
/// Top-level chat-completions response payload (only the fields this module consumes)
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    // Generated completions; expected to contain exactly one entry
    choices: Vec<OpenAiChoice>,
}
124
/// A single completion choice returned by the API
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiResponseMessage,
    // Why generation stopped, e.g. "stop", "length" or "content_filter"
    finish_reason: String,
}
130
/// The assistant message carried inside a choice
#[derive(Debug, Deserialize)]
struct OpenAiResponseMessage {
    // Populated when the model refuses to produce the requested structured output
    #[serde(default)]
    refusal: Option<String>,
    // The JSON-encoded structured content, when generation succeeded
    #[serde(default)]
    content: Option<String>,
}
138
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    /// Smoke test against the real OpenAI API; requires `OPENAI_API_KEY` to be set,
    /// hence it's `#[ignore]`d by default
    #[tokio::test]
    #[ignore] // Real API calls require valid api keys
    async fn test_openai_api() -> Result<()> {
        tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .init();
        let config = AiModelConfig::Openai(OpenAiModelConfig {
            model: "gpt-4.1-nano".into(),
            url: "https://api.openai.com/v1".into(),
            api_key_env: "OPENAI_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        // Fixed typo in the system prompt: "proide" -> "provide"
        let res = client
            .generate_command_suggestions(
                "you're a cli expert, that will provide command suggestions based on what the user want to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!("  # {}", command.description);
            tracing::info!("  {}", command.command);
        }
        Ok(())
    }
}