1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::time::Duration;
6
7use super::api_client::{ApiClient, AuthMethod, AuthProvider};
8use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::embedding::EmbeddingCapable;
10use super::errors::ProviderError;
11use super::formats::databricks::{create_request, response_to_message};
12use super::oauth;
13use super::retry::ProviderRetry;
14use super::utils::{
15 get_model, handle_response_openai_compat, map_http_error_to_provider_error,
16 stream_openai_compat, ImageFormat, RequestLog,
17};
18use crate::config::ConfigError;
19use crate::conversation::message::Message;
20use crate::model::ModelConfig;
21use crate::providers::formats::openai::get_usage;
22use crate::providers::retry::{
23 RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS,
24 DEFAULT_MAX_RETRIES, DEFAULT_MAX_RETRY_INTERVAL_MS,
25};
26use rmcp::model::Tool;
27use serde_json::json;
28
// OAuth defaults matching the Databricks CLI public client registration.
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost";
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
// Request timeout for serving-endpoint calls, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 600;

/// Model used when no model is configured.
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
// Fast model candidate; only enabled when `from_env` finds a matching
// serving endpoint in the workspace.
// NOTE(review): unlike the known models this name has no "databricks-"
// prefix — confirm it matches the workspace's endpoint naming.
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-2-5-flash";
/// Well-known Databricks serving endpoints surfaced in provider metadata.
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
    "databricks-claude-sonnet-4-5",
    "databricks-claude-3-7-sonnet",
    "databricks-meta-llama-3-3-70b-instruct",
    "databricks-meta-llama-3-1-405b-instruct",
    "databricks-dbrx-instruct",
];

/// Documentation link shown to users configuring this provider.
pub const DATABRICKS_DOC_URL: &str =
    "https://docs.databricks.com/en/generative-ai/external-models/index.html";
46
/// How requests to the Databricks workspace are authenticated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
    /// Static personal access token, sent as a bearer token.
    Token(String),
    /// Interactive OAuth flow; a token is obtained on demand via the
    /// `oauth` module using these parameters.
    OAuth {
        host: String,
        client_id: String,
        redirect_url: String,
        scopes: Vec<String>,
    },
}
57
58impl DatabricksAuth {
59 pub fn oauth(host: String) -> Self {
60 Self::OAuth {
61 host,
62 client_id: DEFAULT_CLIENT_ID.to_string(),
63 redirect_url: DEFAULT_REDIRECT_URL.to_string(),
64 scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
65 }
66 }
67
68 pub fn token(token: String) -> Self {
69 Self::Token(token)
70 }
71}
72
/// Adapter that lets the generic `ApiClient` resolve auth headers from a
/// `DatabricksAuth` value.
struct DatabricksAuthProvider {
    auth: DatabricksAuth,
}
76
77#[async_trait]
78impl AuthProvider for DatabricksAuthProvider {
79 async fn get_auth_header(&self) -> Result<(String, String)> {
80 let token = match &self.auth {
81 DatabricksAuth::Token(token) => token.clone(),
82 DatabricksAuth::OAuth {
83 host,
84 client_id,
85 redirect_url,
86 scopes,
87 } => oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?,
88 };
89 Ok(("Authorization".to_string(), format!("Bearer {}", token)))
90 }
91}
92
/// Provider backed by Databricks serving endpoints (OpenAI-compatible
/// invocation format).
#[derive(Debug, serde::Serialize)]
pub struct DatabricksProvider {
    #[serde(skip)]
    api_client: ApiClient,
    auth: DatabricksAuth,
    model: ModelConfig,
    // Databricks accepts OpenAI-style image payloads; fixed at construction.
    image_format: ImageFormat,
    #[serde(skip)]
    retry_config: RetryConfig,
    #[serde(skip)]
    name: String,
}
105
106impl DatabricksProvider {
107 pub async fn from_env(model: ModelConfig) -> Result<Self> {
108 let config = crate::config::Config::global();
109
110 let mut host: Result<String, ConfigError> = config.get_param("DATABRICKS_HOST");
111 if host.is_err() {
112 host = config.get_secret("DATABRICKS_HOST")
113 }
114
115 if host.is_err() {
116 return Err(ConfigError::NotFound(
117 "Did not find DATABRICKS_HOST in either config file or keyring".to_string(),
118 )
119 .into());
120 }
121
122 let host = host?;
123 let retry_config = Self::load_retry_config(config);
124
125 let auth = if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") {
126 DatabricksAuth::token(api_key)
127 } else {
128 DatabricksAuth::oauth(host.clone())
129 };
130
131 let auth_method =
132 AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() }));
133
134 let api_client =
135 ApiClient::with_timeout(host, auth_method, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
136
137 let mut provider = Self {
139 api_client,
140 auth,
141 model: model.clone(),
142 image_format: ImageFormat::OpenAi,
143 retry_config,
144 name: Self::metadata().name,
145 };
146
147 let model_with_fast = if let Ok(Some(models)) = provider.fetch_supported_models().await {
149 if models.contains(&DATABRICKS_DEFAULT_FAST_MODEL.to_string()) {
150 tracing::debug!(
151 "Found {} in Databricks workspace, setting as fast model",
152 DATABRICKS_DEFAULT_FAST_MODEL
153 );
154 model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string())
155 } else {
156 tracing::debug!(
157 "{} not found in Databricks workspace, not setting fast model",
158 DATABRICKS_DEFAULT_FAST_MODEL
159 );
160 model
161 }
162 } else {
163 tracing::debug!("Could not fetch Databricks models, not setting fast model");
164 model
165 };
166
167 provider.model = model_with_fast;
168 Ok(provider)
169 }
170
171 fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
172 let max_retries = config
173 .get_param("DATABRICKS_MAX_RETRIES")
174 .ok()
175 .and_then(|v: String| v.parse::<usize>().ok())
176 .unwrap_or(DEFAULT_MAX_RETRIES);
177
178 let initial_interval_ms = config
179 .get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS")
180 .ok()
181 .and_then(|v: String| v.parse::<u64>().ok())
182 .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
183
184 let backoff_multiplier = config
185 .get_param("DATABRICKS_BACKOFF_MULTIPLIER")
186 .ok()
187 .and_then(|v: String| v.parse::<f64>().ok())
188 .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
189
190 let max_interval_ms = config
191 .get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS")
192 .ok()
193 .and_then(|v: String| v.parse::<u64>().ok())
194 .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
195
196 RetryConfig {
197 max_retries,
198 initial_interval_ms,
199 backoff_multiplier,
200 max_interval_ms,
201 }
202 }
203
204 pub fn from_params(host: String, api_key: String, model: ModelConfig) -> Result<Self> {
205 let auth = DatabricksAuth::token(api_key);
206 let auth_method =
207 AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() }));
208
209 let api_client = ApiClient::with_timeout(host, auth_method, Duration::from_secs(600))?;
210
211 Ok(Self {
212 api_client,
213 auth,
214 model,
215 image_format: ImageFormat::OpenAi,
216 retry_config: RetryConfig::default(),
217 name: Self::metadata().name,
218 })
219 }
220
221 fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
222 if is_embedding {
223 "serving-endpoints/text-embedding-3-small/invocations".to_string()
224 } else {
225 format!("serving-endpoints/{}/invocations", model_name)
226 }
227 }
228
229 async fn post(&self, payload: Value, model_name: Option<&str>) -> Result<Value, ProviderError> {
230 let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
231 let model_to_use = model_name.unwrap_or(&self.model.model_name);
232 let path = self.get_endpoint_path(model_to_use, is_embedding);
233
234 let response = self.api_client.response_post(&path, &payload).await?;
235 handle_response_openai_compat(response).await
236 }
237}
238
#[async_trait]
impl Provider for DatabricksProvider {
    /// Static registration info: id, display name, default/known models,
    /// docs link, and the config keys this provider reads.
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "databricks",
            "Databricks",
            "Models on Databricks AI Gateway",
            DATABRICKS_DEFAULT_MODEL,
            DATABRICKS_KNOWN_MODELS.to_vec(),
            DATABRICKS_DOC_URL,
            vec![
                // host: required, non-secret; token: optional, secret
                ConfigKey::new("DATABRICKS_HOST", true, false, None),
                ConfigKey::new("DATABRICKS_TOKEN", false, true, None),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn retry_config(&self) -> RetryConfig {
        self.retry_config.clone()
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Sends one non-streaming chat completion to the model's serving
    /// endpoint and returns the parsed message plus usage accounting.
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let mut payload =
            create_request(model_config, system, messages, tools, &self.image_format)?;
        // The serving-endpoint path already names the model, so the body's
        // "model" field is redundant and dropped.
        payload
            .as_object_mut()
            .expect("payload should have model key")
            .remove("model");

        let mut log = RequestLog::start(&self.model, &payload)?;

        let response = self
            .with_retry(|| self.post(payload.clone(), Some(&model_config.model_name)))
            .await?;

        let message = response_to_message(&response)?;
        // Missing usage data is tolerated: log at debug and fall back to
        // zeroed counters.
        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
            tracing::debug!("Failed to get usage data");
            Usage::default()
        });
        let response_model = get_model(&response);
        log.write(&response, Some(&usage))?;

        Ok((message, ProviderUsage::new(response_model, usage)))
    }

    /// Streams a chat completion from the serving endpoint, retrying the
    /// request setup per the configured retry policy.
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let model_config = self.model.clone();

        let mut payload =
            create_request(&model_config, system, messages, tools, &self.image_format)?;
        // As in complete_with_model: the path routes to the model, so remove
        // the body's "model" field.
        payload
            .as_object_mut()
            .expect("payload should have model key")
            .remove("model");

        payload
            .as_object_mut()
            .unwrap()
            .insert("stream".to_string(), Value::Bool(true));

        let path = self.get_endpoint_path(&model_config.model_name, false);
        let mut log = RequestLog::start(&self.model, &payload)?;
        let response = self
            .with_retry(|| async {
                let resp = self.api_client.response_post(&path, &payload).await?;
                // Map non-2xx into a typed provider error so the retry layer
                // can decide whether it is retryable.
                if !resp.status().is_success() {
                    let status = resp.status();
                    let error_text = resp.text().await.unwrap_or_default();

                    // The error body may or may not be JSON; pass along
                    // whatever parses.
                    let json_payload = serde_json::from_str::<Value>(&error_text).ok();
                    return Err(map_http_error_to_provider_error(status, json_payload));
                }
                Ok(resp)
            })
            .await
            .inspect_err(|e| {
                // Best-effort: record the failure in the request log.
                let _ = log.error(e);
            })?;

        stream_openai_compat(response, log)
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_embeddings(&self) -> bool {
        true
    }

    /// Provider-trait adapter over the `EmbeddingCapable` impl below,
    /// converting its anyhow errors into `ProviderError`s.
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        EmbeddingCapable::create_embeddings(self, texts)
            .await
            .map_err(|e| ProviderError::ExecutionError(e.to_string()))
    }

    /// Lists serving-endpoint names in the workspace.
    ///
    /// All failures (transport, non-2xx, bad JSON, unexpected shape) are
    /// logged as warnings and collapsed to `Ok(None)`, so callers treat the
    /// result as advisory rather than an error.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let response = match self
            .api_client
            .response_get("api/2.0/serving-endpoints")
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                tracing::warn!("Failed to fetch Databricks models: {}", e);
                return Ok(None);
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            if let Ok(error_text) = response.text().await {
                tracing::warn!(
                    "Failed to fetch Databricks models: {} - {}",
                    status,
                    error_text
                );
            } else {
                tracing::warn!("Failed to fetch Databricks models: {}", status);
            }
            return Ok(None);
        }

        let json: Value = match response.json().await {
            Ok(json) => json,
            Err(e) => {
                tracing::warn!("Failed to parse Databricks API response: {}", e);
                return Ok(None);
            }
        };

        let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
            Some(endpoints) => endpoints,
            None => {
                tracing::warn!(
                    "Unexpected response format from Databricks API: missing 'endpoints' array"
                );
                return Ok(None);
            }
        };

        // Collect the "name" of each endpoint entry; entries without one are
        // silently skipped.
        let models: Vec<String> = endpoints
            .iter()
            .filter_map(|endpoint| {
                endpoint
                    .get("name")
                    .and_then(|v| v.as_str())
                    .map(|name| name.to_string())
            })
            .collect();

        if models.is_empty() {
            Ok(None)
        } else {
            Ok(Some(models))
        }
    }
}
422
423#[async_trait]
424impl EmbeddingCapable for DatabricksProvider {
425 async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
426 if texts.is_empty() {
427 return Ok(vec![]);
428 }
429
430 let request = json!({
431 "input": texts,
432 });
433
434 let response = self.with_retry(|| self.post(request.clone(), None)).await?;
435
436 let embeddings = response["data"]
437 .as_array()
438 .ok_or_else(|| anyhow::anyhow!("Invalid response format: missing data array"))?
439 .iter()
440 .map(|item| {
441 item["embedding"]
442 .as_array()
443 .ok_or_else(|| anyhow::anyhow!("Invalid embedding format"))?
444 .iter()
445 .map(|v| v.as_f64().map(|f| f as f32))
446 .collect::<Option<Vec<f32>>>()
447 .ok_or_else(|| anyhow::anyhow!("Invalid embedding values"))
448 })
449 .collect::<Result<Vec<Vec<f32>>>>()?;
450
451 Ok(embeddings)
452 }
453}