rusty_commit/providers/xai.rs
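//! xAI (Grok) provider: sends chat completion requests to xAI's
//! OpenAI-compatible API (default base URL `https://api.x.ai/v1`)
//! via the `async_openai` client.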

use anyhow::{Context, Result};
use async_openai::{
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
        CreateChatCompletionRequestArgs,
    },
    Client,
};
use async_trait::async_trait;

use super::{split_prompt, AIProvider};
use crate::config::accounts::AccountConfig;
use crate::config::Config;

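/// xAI provider backed by an OpenAI-compatible `async_openai` client.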
pub struct XAIProvider {
    client: Client<OpenAIConfig>,
    model: String,
}

impl XAIProvider {
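    /// Build a provider from the global config. Requires an xAI API key
    /// (`RCO_API_KEY`); defaults to `https://api.x.ai/v1` and the
    /// `grok-beta` model when no URL or model is configured.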
    pub fn new(config: &Config) -> Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .context("xAI API key not configured.\nRun: rco config set RCO_API_KEY=<your_key>\nGet your API key from: https://x.ai/api")?;

        let openai_config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(config.api_url.as_deref().unwrap_or("https://api.x.ai/v1"));

        let client = Client::with_config(openai_config);
        let model = config.model.as_deref().unwrap_or("grok-beta").to_string();

        Ok(Self { client, model })
    }

    /// Create a provider from a specific account configuration, falling back
    /// to the global config for the API URL and model.
    #[allow(dead_code)]
    pub fn from_account(account: &AccountConfig, api_key: &str, config: &Config) -> Result<Self> {
        let openai_config = OpenAIConfig::new().with_api_key(api_key).with_api_base(
            account
                .api_url
                .as_deref()
                .or(config.api_url.as_deref())
                .unwrap_or("https://api.x.ai/v1"),
        );

        let client = Client::with_config(openai_config);
        let model = account
            .model
            .as_deref()
            .or(config.model.as_deref())
            .unwrap_or("grok-beta")
            .to_string();

        Ok(Self { client, model })
    }
}

#[async_trait]
impl AIProvider for XAIProvider {
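    /// Generate a commit message for the given diff by sending the shared
    /// system/user prompt split to xAI's chat completions endpoint.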
    async fn generate_commit_message(
        &self,
        diff: &str,
        context: Option<&str>,
        full_gitmoji: bool,
        config: &Config,
    ) -> Result<String> {
        let (system_prompt, user_prompt) = split_prompt(diff, context, config, full_gitmoji);

        let messages = vec![
            ChatCompletionRequestSystemMessage::from(system_prompt).into(),
            ChatCompletionRequestUserMessage::from(user_prompt).into(),
        ];

        let request = CreateChatCompletionRequestArgs::default()
            .model(&self.model)
            .messages(messages)
            .temperature(0.7)
            .max_tokens(config.tokens_max_output.unwrap_or(500) as u16)
            .build()?;

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .context("Failed to generate commit message from xAI")?;

        let message = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.as_ref())
            .context("xAI returned an empty response")?
            .trim()
            .to_string();

        Ok(message)
    }
}

/// Registry builder for the xAI provider, registered as "xai" with the
/// aliases "grok" and "x-ai".
pub struct XAIProviderBuilder;

impl super::registry::ProviderBuilder for XAIProviderBuilder {
    fn name(&self) -> &'static str {
        "xai"
    }

    fn aliases(&self) -> Vec<&'static str> {
        vec!["grok", "x-ai"]
    }

    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
        Ok(Box::new(XAIProvider::new(config)?))
    }

    fn requires_api_key(&self) -> bool {
        true
    }

    fn default_model(&self) -> Option<&'static str> {
        Some("grok-beta")
    }
}