// devsper_providers/github.rs — GitHub Models LLM provider.
1use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// Base URL of the GitHub Models OpenAI-compatible inference API.
/// NOTE(review): GitHub's current docs list `https://models.github.ai/inference`
/// as the endpoint — confirm this host is the one this deployment targets.
const BASE_URL: &str = "https://models.github.com/inference/v1";
9
/// GitHub Models API provider — OpenAI-compatible, Bearer token auth.
/// Expects model names prefixed with "github:" (e.g. "github:gpt-4o").
pub struct GithubModelsProvider {
    client: Client, // shared HTTP client, reused across requests
    token: String,  // sent on every request as `Authorization: Bearer <token>`
}
16
17impl GithubModelsProvider {
18    pub fn new(token: impl Into<String>) -> Self {
19        Self {
20            client: Client::new(),
21            token: token.into(),
22        }
23    }
24}
25
/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
/// Borrows from the caller's request to avoid copying message text.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the JSON when unset so the API applies its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
35
/// One chat message in OpenAI wire format ("system" / "user" / "assistant").
#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}
41
/// Top-level success payload from `/chat/completions`.
/// NOTE(review): `usage` is a required field here, so deserialization fails
/// if the API ever omits it — confirm against the endpoint's schema.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String, // model id actually served, as echoed by the API
}
48
/// A single completion choice; only the first one is consumed by `generate`.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>, // e.g. "stop", "length", "tool_calls"
}
54
/// Assistant message inside a choice; `content` may be null
/// (e.g. when the model answers with tool calls only).
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}
59
/// Token accounting reported by the API.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,     // input-side token count
    completion_tokens: u32, // output-side token count
}
65
66fn role_str(role: &LlmRole) -> &'static str {
67    match role {
68        LlmRole::System => "system",
69        LlmRole::User | LlmRole::Tool => "user",
70        LlmRole::Assistant => "assistant",
71    }
72}
73
#[async_trait]
impl LlmProvider for GithubModelsProvider {
    /// Runs one non-streaming chat completion against GitHub Models.
    ///
    /// Translates `req` into the OpenAI wire format (stripping the
    /// "github:" routing prefix from the model name), POSTs it with Bearer
    /// auth, and maps the first choice back into an `LlmResponse`.
    /// `tool_calls` in the returned response is always empty — tool-call
    /// payloads are not deserialized here, even when the API finishes with
    /// reason "tool_calls".
    ///
    /// # Errors
    /// Fails on transport errors, a non-2xx HTTP status (status and body
    /// are included in the message), malformed JSON, or an empty `choices`
    /// list.
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

        // OpenTelemetry GenAI-style span; the response-derived fields are
        // declared Empty now and recorded only after a successful call.
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

        // Strip "github:" prefix before sending to API — the upstream
        // endpoint expects the bare model id (e.g. "gpt-4o").
        let model = req.model.strip_prefix("github:").unwrap_or(&req.model);

        // Borrow each message into the wire shape; Tool-role messages are
        // relayed as "user" (see `role_str`).
        let messages: Vec<OaiMessage> = req
            .messages
            .iter()
            .map(|m| OaiMessage {
                role: role_str(&m.role),
                content: &m.content,
            })
            .collect();

        let body = OaiRequest {
            model,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
        };

        debug!(model = %model, provider = "github-models", "GitHub Models request");

        // The whole send/parse sequence runs inside the span so any events
        // it emits are attributed to this chat operation.
        let result = async {
            let resp = self
                .client
                .post(format!("{BASE_URL}/chat/completions"))
                .header("Authorization", format!("Bearer {}", self.token))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await?;

            if !resp.status().is_success() {
                // Surface the response body: upstream error details live there.
                let status = resp.status();
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("github-models API error {status}: {text}"));
            }

            let data: OaiResponse = resp.json().await?;
            // Only the first choice is consumed; n > 1 is never requested.
            let choice = data
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No choices in response"))?;

            // Unknown finish reasons (e.g. "content_filter") fall through to
            // EndTurn rather than producing an error.
            let stop_reason = match choice.finish_reason.as_deref() {
                Some("tool_calls") => StopReason::ToolUse,
                Some("length") => StopReason::MaxTokens,
                Some("stop") | None => StopReason::EndTurn,
                _ => StopReason::EndTurn,
            };

            Ok(LlmResponse {
                // `content` can be null (tool-call-only answers) — map to "".
                content: choice.message.content.unwrap_or_default(),
                tool_calls: vec![],
                input_tokens: data.usage.prompt_tokens,
                output_tokens: data.usage.completion_tokens,
                model: data.model,
                stop_reason,
            })
        }
        .instrument(span.clone())
        .await;

        // Record response attributes on success; on error they stay Empty.
        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    /// Provider identifier used in logs, spans, and error messages.
    fn name(&self) -> &str {
        "github-models"
    }

    /// Claims any model id carrying the "github:" routing prefix.
    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("github:")
    }
}