//! Databricks provider: chat completions, streaming, and embeddings served
//! through Databricks serving endpoints (providers/databricks.rs).
1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::time::Duration;
6
7use super::api_client::{ApiClient, AuthMethod, AuthProvider};
8use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
9use super::embedding::EmbeddingCapable;
10use super::errors::ProviderError;
11use super::formats::databricks::{create_request, response_to_message};
12use super::oauth;
13use super::retry::ProviderRetry;
14use super::utils::{
15    get_model, handle_response_openai_compat, map_http_error_to_provider_error,
16    stream_openai_compat, ImageFormat, RequestLog,
17};
18use crate::config::ConfigError;
19use crate::conversation::message::Message;
20use crate::model::ModelConfig;
21use crate::providers::formats::openai::get_usage;
22use crate::providers::retry::{
23    RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS,
24    DEFAULT_MAX_RETRIES, DEFAULT_MAX_RETRY_INTERVAL_MS,
25};
26use rmcp::model::Tool;
27use serde_json::json;
28
// OAuth defaults for the Databricks CLI user-to-machine flow.
const DEFAULT_CLIENT_ID: &str = "databricks-cli";
const DEFAULT_REDIRECT_URL: &str = "http://localhost";
const DEFAULT_SCOPES: &[&str] = &["all-apis", "offline_access"];
// HTTP client timeout in seconds, applied to every request made by the ApiClient.
const DEFAULT_TIMEOUT_SECS: u64 = 600;

/// Default model used when the caller does not specify one.
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-sonnet-4";
// Candidate "fast" model; only enabled if found in the workspace (see from_env).
const DATABRICKS_DEFAULT_FAST_MODEL: &str = "gemini-2-5-flash";
/// Models advertised in provider metadata; not an exhaustive list of what a
/// workspace may actually serve.
pub const DATABRICKS_KNOWN_MODELS: &[&str] = &[
    "databricks-claude-sonnet-4-5",
    "databricks-claude-3-7-sonnet",
    "databricks-meta-llama-3-3-70b-instruct",
    "databricks-meta-llama-3-1-405b-instruct",
    "databricks-dbrx-instruct",
];

/// Documentation link surfaced in provider metadata.
pub const DATABRICKS_DOC_URL: &str =
    "https://docs.databricks.com/en/generative-ai/external-models/index.html";
46
/// How the provider authenticates against a Databricks workspace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DatabricksAuth {
    /// A static personal access token used as-is in the Authorization header.
    Token(String),
    /// OAuth parameters; a bearer token is fetched on demand per request.
    OAuth {
        host: String,
        client_id: String,
        redirect_url: String,
        scopes: Vec<String>,
    },
}
57
58impl DatabricksAuth {
59    pub fn oauth(host: String) -> Self {
60        Self::OAuth {
61            host,
62            client_id: DEFAULT_CLIENT_ID.to_string(),
63            redirect_url: DEFAULT_REDIRECT_URL.to_string(),
64            scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(),
65        }
66    }
67
68    pub fn token(token: String) -> Self {
69        Self::Token(token)
70    }
71}
72
/// Adapter that lets `ApiClient` obtain Databricks auth headers on demand.
struct DatabricksAuthProvider {
    auth: DatabricksAuth,
}
76
77#[async_trait]
78impl AuthProvider for DatabricksAuthProvider {
79    async fn get_auth_header(&self) -> Result<(String, String)> {
80        let token = match &self.auth {
81            DatabricksAuth::Token(token) => token.clone(),
82            DatabricksAuth::OAuth {
83                host,
84                client_id,
85                redirect_url,
86                scopes,
87            } => oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?,
88        };
89        Ok(("Authorization".to_string(), format!("Bearer {}", token)))
90    }
91}
92
/// Provider backed by Databricks serving endpoints.
///
/// Non-serializable runtime pieces (HTTP client, retry policy, display name)
/// are skipped during serialization.
#[derive(Debug, serde::Serialize)]
pub struct DatabricksProvider {
    #[serde(skip)]
    api_client: ApiClient,
    auth: DatabricksAuth,
    model: ModelConfig,
    // Wire format used when encoding images into requests.
    image_format: ImageFormat,
    #[serde(skip)]
    retry_config: RetryConfig,
    #[serde(skip)]
    name: String,
}
105
106impl DatabricksProvider {
107    pub async fn from_env(model: ModelConfig) -> Result<Self> {
108        let config = crate::config::Config::global();
109
110        let mut host: Result<String, ConfigError> = config.get_param("DATABRICKS_HOST");
111        if host.is_err() {
112            host = config.get_secret("DATABRICKS_HOST")
113        }
114
115        if host.is_err() {
116            return Err(ConfigError::NotFound(
117                "Did not find DATABRICKS_HOST in either config file or keyring".to_string(),
118            )
119            .into());
120        }
121
122        let host = host?;
123        let retry_config = Self::load_retry_config(config);
124
125        let auth = if let Ok(api_key) = config.get_secret("DATABRICKS_TOKEN") {
126            DatabricksAuth::token(api_key)
127        } else {
128            DatabricksAuth::oauth(host.clone())
129        };
130
131        let auth_method =
132            AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() }));
133
134        let api_client =
135            ApiClient::with_timeout(host, auth_method, Duration::from_secs(DEFAULT_TIMEOUT_SECS))?;
136
137        // Create the provider without the fast model first
138        let mut provider = Self {
139            api_client,
140            auth,
141            model: model.clone(),
142            image_format: ImageFormat::OpenAi,
143            retry_config,
144            name: Self::metadata().name,
145        };
146
147        // Check if the default fast model exists in the workspace
148        let model_with_fast = if let Ok(Some(models)) = provider.fetch_supported_models().await {
149            if models.contains(&DATABRICKS_DEFAULT_FAST_MODEL.to_string()) {
150                tracing::debug!(
151                    "Found {} in Databricks workspace, setting as fast model",
152                    DATABRICKS_DEFAULT_FAST_MODEL
153                );
154                model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL.to_string())
155            } else {
156                tracing::debug!(
157                    "{} not found in Databricks workspace, not setting fast model",
158                    DATABRICKS_DEFAULT_FAST_MODEL
159                );
160                model
161            }
162        } else {
163            tracing::debug!("Could not fetch Databricks models, not setting fast model");
164            model
165        };
166
167        provider.model = model_with_fast;
168        Ok(provider)
169    }
170
171    fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
172        let max_retries = config
173            .get_param("DATABRICKS_MAX_RETRIES")
174            .ok()
175            .and_then(|v: String| v.parse::<usize>().ok())
176            .unwrap_or(DEFAULT_MAX_RETRIES);
177
178        let initial_interval_ms = config
179            .get_param("DATABRICKS_INITIAL_RETRY_INTERVAL_MS")
180            .ok()
181            .and_then(|v: String| v.parse::<u64>().ok())
182            .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
183
184        let backoff_multiplier = config
185            .get_param("DATABRICKS_BACKOFF_MULTIPLIER")
186            .ok()
187            .and_then(|v: String| v.parse::<f64>().ok())
188            .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
189
190        let max_interval_ms = config
191            .get_param("DATABRICKS_MAX_RETRY_INTERVAL_MS")
192            .ok()
193            .and_then(|v: String| v.parse::<u64>().ok())
194            .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
195
196        RetryConfig {
197            max_retries,
198            initial_interval_ms,
199            backoff_multiplier,
200            max_interval_ms,
201        }
202    }
203
204    pub fn from_params(host: String, api_key: String, model: ModelConfig) -> Result<Self> {
205        let auth = DatabricksAuth::token(api_key);
206        let auth_method =
207            AuthMethod::Custom(Box::new(DatabricksAuthProvider { auth: auth.clone() }));
208
209        let api_client = ApiClient::with_timeout(host, auth_method, Duration::from_secs(600))?;
210
211        Ok(Self {
212            api_client,
213            auth,
214            model,
215            image_format: ImageFormat::OpenAi,
216            retry_config: RetryConfig::default(),
217            name: Self::metadata().name,
218        })
219    }
220
221    fn get_endpoint_path(&self, model_name: &str, is_embedding: bool) -> String {
222        if is_embedding {
223            "serving-endpoints/text-embedding-3-small/invocations".to_string()
224        } else {
225            format!("serving-endpoints/{}/invocations", model_name)
226        }
227    }
228
229    async fn post(&self, payload: Value, model_name: Option<&str>) -> Result<Value, ProviderError> {
230        let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
231        let model_to_use = model_name.unwrap_or(&self.model.model_name);
232        let path = self.get_endpoint_path(model_to_use, is_embedding);
233
234        let response = self.api_client.response_post(&path, &payload).await?;
235        handle_response_openai_compat(response).await
236    }
237}
238
#[async_trait]
impl Provider for DatabricksProvider {
    /// Static metadata: display names, default/known models, docs URL, and
    /// the config keys (`DATABRICKS_HOST` required, `DATABRICKS_TOKEN` an
    /// optional secret).
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "databricks",
            "Databricks",
            "Models on Databricks AI Gateway",
            DATABRICKS_DEFAULT_MODEL,
            DATABRICKS_KNOWN_MODELS.to_vec(),
            DATABRICKS_DOC_URL,
            vec![
                ConfigKey::new("DATABRICKS_HOST", true, false, None),
                ConfigKey::new("DATABRICKS_TOKEN", false, true, None),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn retry_config(&self) -> RetryConfig {
        self.retry_config.clone()
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Non-streaming completion against the endpoint for `model_config`,
    /// with retries and request logging.
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let mut payload =
            create_request(model_config, system, messages, tools, &self.image_format)?;
        // Databricks addresses the model via the URL path, not the body, so
        // the OpenAI-style "model" key is stripped from the payload.
        payload
            .as_object_mut()
            .expect("payload should have model key")
            .remove("model");

        // NOTE(review): the request log is keyed on self.model even though
        // the request targets model_config — confirm this is intentional.
        let mut log = RequestLog::start(&self.model, &payload)?;

        let response = self
            .with_retry(|| self.post(payload.clone(), Some(&model_config.model_name)))
            .await?;

        let message = response_to_message(&response)?;
        // Missing usage data is tolerated; fall back to a zeroed Usage.
        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
            tracing::debug!("Failed to get usage data");
            Usage::default()
        });
        let response_model = get_model(&response);
        log.write(&response, Some(&usage))?;

        Ok((message, ProviderUsage::new(response_model, usage)))
    }

    /// Streaming completion using the provider's configured model.
    /// Non-success HTTP statuses are drained and mapped to ProviderError so
    /// the retry wrapper can act on them.
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let model_config = self.model.clone();

        let mut payload =
            create_request(&model_config, system, messages, tools, &self.image_format)?;
        // As in complete_with_model: the model is in the URL path, not the body.
        payload
            .as_object_mut()
            .expect("payload should have model key")
            .remove("model");

        payload
            .as_object_mut()
            .unwrap()
            .insert("stream".to_string(), Value::Bool(true));

        let path = self.get_endpoint_path(&model_config.model_name, false);
        let mut log = RequestLog::start(&self.model, &payload)?;
        let response = self
            .with_retry(|| async {
                let resp = self.api_client.response_post(&path, &payload).await?;
                if !resp.status().is_success() {
                    let status = resp.status();
                    let error_text = resp.text().await.unwrap_or_default();

                    // Parse as JSON if possible to pass to map_http_error_to_provider_error
                    let json_payload = serde_json::from_str::<Value>(&error_text).ok();
                    return Err(map_http_error_to_provider_error(status, json_payload));
                }
                Ok(resp)
            })
            .await
            // Record the failure in the request log; logging errors are ignored.
            .inspect_err(|e| {
                let _ = log.error(e);
            })?;

        stream_openai_compat(response, log)
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_embeddings(&self) -> bool {
        true
    }

    /// Bridge to the EmbeddingCapable impl, converting anyhow errors into
    /// ProviderError::ExecutionError.
    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        EmbeddingCapable::create_embeddings(self, texts)
            .await
            .map_err(|e| ProviderError::ExecutionError(e.to_string()))
    }

    /// List serving-endpoint names from the workspace.
    ///
    /// Best-effort: every failure mode (request error, non-success status,
    /// unparsable body, missing 'endpoints' array, empty list) logs a warning
    /// and returns Ok(None) rather than an error.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let response = match self
            .api_client
            .response_get("api/2.0/serving-endpoints")
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                tracing::warn!("Failed to fetch Databricks models: {}", e);
                return Ok(None);
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            if let Ok(error_text) = response.text().await {
                tracing::warn!(
                    "Failed to fetch Databricks models: {} - {}",
                    status,
                    error_text
                );
            } else {
                tracing::warn!("Failed to fetch Databricks models: {}", status);
            }
            return Ok(None);
        }

        let json: Value = match response.json().await {
            Ok(json) => json,
            Err(e) => {
                tracing::warn!("Failed to parse Databricks API response: {}", e);
                return Ok(None);
            }
        };

        let endpoints = match json.get("endpoints").and_then(|v| v.as_array()) {
            Some(endpoints) => endpoints,
            None => {
                tracing::warn!(
                    "Unexpected response format from Databricks API: missing 'endpoints' array"
                );
                return Ok(None);
            }
        };

        // Collect each endpoint's "name" field; entries without one are skipped.
        let models: Vec<String> = endpoints
            .iter()
            .filter_map(|endpoint| {
                endpoint
                    .get("name")
                    .and_then(|v| v.as_str())
                    .map(|name| name.to_string())
            })
            .collect();

        if models.is_empty() {
            Ok(None)
        } else {
            Ok(Some(models))
        }
    }
}
422
423#[async_trait]
424impl EmbeddingCapable for DatabricksProvider {
425    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
426        if texts.is_empty() {
427            return Ok(vec![]);
428        }
429
430        let request = json!({
431            "input": texts,
432        });
433
434        let response = self.with_retry(|| self.post(request.clone(), None)).await?;
435
436        let embeddings = response["data"]
437            .as_array()
438            .ok_or_else(|| anyhow::anyhow!("Invalid response format: missing data array"))?
439            .iter()
440            .map(|item| {
441                item["embedding"]
442                    .as_array()
443                    .ok_or_else(|| anyhow::anyhow!("Invalid embedding format"))?
444                    .iter()
445                    .map(|v| v.as_f64().map(|f| f as f32))
446                    .collect::<Option<Vec<f32>>>()
447                    .ok_or_else(|| anyhow::anyhow!("Invalid embedding values"))
448            })
449            .collect::<Result<Vec<Vec<f32>>>>()?;
450
451        Ok(embeddings)
452    }
453}