intelli_shell/ai/
anthropic.rs

1use std::fmt::Debug;
2
3use color_eyre::eyre::Context;
4use reqwest::{Client, RequestBuilder, Response, header::HeaderName};
5use schemars::{JsonSchema, Schema};
6use serde::{Deserialize, de::DeserializeOwned};
7use serde_json::{Value as Json, json};
8
9use super::{AiProvider, AiProviderBase};
10use crate::{
11    config::AnthropicModelConfig,
12    errors::{Result, UserFacingError},
13};
14
/// Name of the single tool the model is forced to call so its answer arrives as structured JSON.
const TOOL_NAME: &'static str = "propose_response";
16
17impl AiProviderBase for AnthropicModelConfig {
18    fn provider_name(&self) -> &'static str {
19        "Anthropic"
20    }
21
22    fn auth_header(&self, api_key: String) -> (HeaderName, String) {
23        (HeaderName::from_static("x-api-key"), api_key)
24    }
25
26    fn api_key_env_var_name(&self) -> &str {
27        &self.api_key_env
28    }
29
30    fn build_request(
31        &self,
32        client: &Client,
33        sys_prompt: &str,
34        user_prompt: &str,
35        json_schema: &Schema,
36    ) -> RequestBuilder {
37        // Request body
38        // https://docs.anthropic.com/en/api/messages
39        let request_body = json!({
40            "model": self.model,
41            "system": sys_prompt,
42            "messages": [
43                {
44                    "role": "user",
45                    "content": user_prompt
46                }
47            ],
48            "max_tokens": 4096,
49            "tools": [{
50                "name": TOOL_NAME,
51                "description": "Propose an structured response to the end user",
52                "input_schema": json_schema,
53            }],
54            "tool_choice": {
55                "type": "tool",
56                "name": TOOL_NAME,
57                "disable_parallel_tool_use": true
58            }
59        });
60
61        tracing::trace!("Request:\n{request_body:#}");
62
63        // Messages url
64        let url = format!("{}/messages", self.url);
65
66        // Request
67        client
68            .post(url)
69            .header("anthropic-version", "2023-06-01")
70            .json(&request_body)
71    }
72}
73
74impl AiProvider for AnthropicModelConfig {
75    async fn parse_response<T>(&self, res: Response) -> Result<T>
76    where
77        T: DeserializeOwned + JsonSchema + Debug,
78    {
79        // Parse successful response
80        let res: Json = res.json().await.wrap_err("Anthropic response not a json")?;
81        tracing::trace!("Response:\n{res:#}");
82        let mut res: AnthropicResponse<T> =
83            serde_json::from_value(res).wrap_err("Couldn't parse Anthropic response")?;
84
85        // Validate the response content
86        if res.stop_reason != "end_turn" && res.stop_reason != "tool_use" {
87            tracing::error!("OpenAI response got an invalid stop reason: {}", res.stop_reason);
88            return Err(UserFacingError::AiRequestFailed(format!(
89                "couldn't generate a valid response: {}",
90                res.stop_reason
91            ))
92            .into());
93        }
94
95        if res.content.is_empty() {
96            tracing::error!("Response got no content: {res:?}");
97            return Err(UserFacingError::AiRequestFailed(String::from("received response with no content")).into());
98        } else if res.content.len() > 1 {
99            tracing::warn!("Response got {} content blocks", res.content.len());
100        }
101
102        let block = res.content.remove(0);
103        if block.r#type != "tool_use" {
104            tracing::error!("Anthropic response got an invalid content type: {}", block.r#type);
105            return Err(UserFacingError::AiRequestFailed(format!("unexpected response type: {}", block.r#type)).into());
106        }
107
108        if block.name != TOOL_NAME {
109            tracing::error!("Anthropic response got an invalid tool name: {}", block.name);
110            return Err(UserFacingError::AiRequestFailed(format!("received invalid tool name: {}", block.name)).into());
111        }
112
113        Ok(block.input)
114    }
115}
116
117#[derive(Debug, Deserialize)]
118struct AnthropicResponse<T> {
119    content: Vec<ContentBlock<T>>,
120    stop_reason: String,
121}
122
123#[derive(Debug, Deserialize)]
124struct ContentBlock<T> {
125    r#type: String,
126    name: String,
127    input: T,
128}
129
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    /// Smoke test against the real Anthropic API.
    ///
    /// Ignored by default; it needs a valid key in the env var configured below
    /// (`ANTHROPIC_API_KEY`).
    #[tokio::test]
    #[ignore] // Real API calls require valid api keys
    async fn test_anthropic_api() -> Result<()> {
        // `try_init` rather than `init`: another test in the same binary may already have
        // installed a global subscriber, and `init` would panic in that case
        let _ = tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .try_init();
        let config = AiModelConfig::Anthropic(AnthropicModelConfig {
            model: "claude-sonnet-4-0".into(),
            url: "https://api.anthropic.com/v1".into(),
            api_key_env: "ANTHROPIC_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        let res = client
            .generate_command_suggestions(
                "you're a cli expert, that will provide command suggestions based on what the user wants to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!("  # {}", command.description);
            tracing::info!("  {}", command.command);
        }
        Ok(())
    }
}