// devsper_providers/github.rs
use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

/// Base URL of the GitHub Models OpenAI-compatible inference API.
// NOTE(review): GitHub's current docs list `https://models.github.ai/inference`
// as the inference endpoint — confirm `models.github.com/inference/v1` resolves.
const BASE_URL: &str = "https://models.github.com/inference/v1";

/// [`LlmProvider`] backed by the GitHub Models inference API
/// (OpenAI-compatible chat completions).
pub struct GithubModelsProvider {
    client: Client,
    // GitHub token, sent as a Bearer credential on every request.
    token: String,
}

17impl GithubModelsProvider {
18 pub fn new(token: impl Into<String>) -> Self {
19 Self {
20 client: Client::new(),
21 token: token.into(),
22 }
23 }
24}
25
/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
/// Borrows from the originating [`LlmRequest`] to avoid copies.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the JSON when None so the server's defaults apply.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}

/// A single chat message in the outgoing request.
#[derive(Serialize)]
struct OaiMessage<'a> {
    // Wire-format role string: "system", "user", or "assistant" (see `role_str`).
    role: &'a str,
    content: &'a str,
}

/// Top-level chat-completion response. Only the fields this provider
/// consumes are declared; serde ignores unknown fields by default.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    // NOTE(review): `usage` is required here, so deserialization fails if the
    // API ever omits it — confirm the endpoint always returns usage data.
    usage: OaiUsage,
    model: String,
}

/// One completion choice; `generate` uses only the first.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    // e.g. "stop", "length", "tool_calls"; may be absent.
    finish_reason: Option<String>,
}

/// Assistant message payload within a choice.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    // May be null in the JSON (treated as empty content by `generate`).
    content: Option<String>,
}

/// Token accounting reported by the API.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

66fn role_str(role: &LlmRole) -> &'static str {
67 match role {
68 LlmRole::System => "system",
69 LlmRole::User | LlmRole::Tool => "user",
70 LlmRole::Assistant => "assistant",
71 }
72}
73
74#[async_trait]
75impl LlmProvider for GithubModelsProvider {
76 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
77 use tracing::Instrument;
78
79 let span = tracing::info_span!(
80 "gen_ai.chat",
81 "gen_ai.system" = self.name(),
82 "gen_ai.operation.name" = "chat",
83 "gen_ai.request.model" = req.model.as_str(),
84 "gen_ai.request.max_tokens" = req.max_tokens,
85 "gen_ai.response.model" = tracing::field::Empty,
86 "gen_ai.usage.input_tokens" = tracing::field::Empty,
87 "gen_ai.usage.output_tokens" = tracing::field::Empty,
88 );
89
90 let model = req.model.strip_prefix("github:").unwrap_or(&req.model);
92
93 let messages: Vec<OaiMessage> = req
94 .messages
95 .iter()
96 .map(|m| OaiMessage {
97 role: role_str(&m.role),
98 content: &m.content,
99 })
100 .collect();
101
102 let body = OaiRequest {
103 model,
104 messages,
105 max_tokens: req.max_tokens,
106 temperature: req.temperature,
107 };
108
109 debug!(model = %model, provider = "github-models", "GitHub Models request");
110
111 let result = async {
112 let resp = self
113 .client
114 .post(format!("{BASE_URL}/chat/completions"))
115 .header("Authorization", format!("Bearer {}", self.token))
116 .header("Content-Type", "application/json")
117 .json(&body)
118 .send()
119 .await?;
120
121 if !resp.status().is_success() {
122 let status = resp.status();
123 let text = resp.text().await.unwrap_or_default();
124 return Err(anyhow!("github-models API error {status}: {text}"));
125 }
126
127 let data: OaiResponse = resp.json().await?;
128 let choice = data
129 .choices
130 .into_iter()
131 .next()
132 .ok_or_else(|| anyhow!("No choices in response"))?;
133
134 let stop_reason = match choice.finish_reason.as_deref() {
135 Some("tool_calls") => StopReason::ToolUse,
136 Some("length") => StopReason::MaxTokens,
137 Some("stop") | None => StopReason::EndTurn,
138 _ => StopReason::EndTurn,
139 };
140
141 Ok(LlmResponse {
142 content: choice.message.content.unwrap_or_default(),
143 tool_calls: vec![],
144 input_tokens: data.usage.prompt_tokens,
145 output_tokens: data.usage.completion_tokens,
146 model: data.model,
147 stop_reason,
148 })
149 }
150 .instrument(span.clone())
151 .await;
152
153 if let Ok(ref resp) = result {
154 span.record("gen_ai.response.model", resp.model.as_str());
155 span.record("gen_ai.usage.input_tokens", resp.input_tokens);
156 span.record("gen_ai.usage.output_tokens", resp.output_tokens);
157 }
158 result
159 }
160
161 fn name(&self) -> &str {
162 "github-models"
163 }
164
165 fn supports_model(&self, model: &str) -> bool {
166 model.starts_with("github:")
167 }
168}