Skip to main content

claude_agent/client/adapter/
traits.rs

1//! Provider adapter trait definition.
2
3use std::fmt::Debug;
4
5use async_trait::async_trait;
6
7use super::config::{ModelType, ProviderConfig};
8use crate::client::messages::{CountTokensRequest, CountTokensResponse, CreateMessageRequest};
9use crate::types::ApiResponse;
10use crate::{Error, Result};
11
12#[async_trait]
13pub trait ProviderAdapter: Send + Sync + Debug {
14    fn config(&self) -> &ProviderConfig;
15
16    fn name(&self) -> &'static str;
17
18    /// Returns the base URL for API requests (e.g. `https://api.anthropic.com`).
19    fn base_url(&self) -> &str {
20        "https://api.anthropic.com"
21    }
22
23    fn model(&self, model_type: ModelType) -> &str {
24        self.config().models.get(model_type)
25    }
26
27    async fn build_url(&self, model: &str, stream: bool) -> String;
28
29    async fn prepare_request(&self, request: CreateMessageRequest) -> CreateMessageRequest {
30        request
31    }
32
33    async fn transform_request(&self, request: CreateMessageRequest) -> Result<serde_json::Value>;
34
35    fn transform_response(&self, response: serde_json::Value) -> Result<ApiResponse> {
36        serde_json::from_value(response).map_err(|e| Error::Parse(e.to_string()))
37    }
38
39    async fn apply_auth_headers(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
40        req
41    }
42
43    async fn send(
44        &self,
45        http: &reqwest::Client,
46        request: CreateMessageRequest,
47    ) -> Result<ApiResponse>;
48
49    async fn send_stream(
50        &self,
51        http: &reqwest::Client,
52        request: CreateMessageRequest,
53    ) -> Result<reqwest::Response>;
54
55    fn supports_credential_refresh(&self) -> bool {
56        false
57    }
58
59    async fn ensure_fresh_credentials(&self) -> Result<()> {
60        Ok(())
61    }
62
63    async fn refresh_credentials(&self) -> Result<()> {
64        Ok(())
65    }
66
67    async fn count_tokens(
68        &self,
69        _http: &reqwest::Client,
70        _request: CountTokensRequest,
71    ) -> Result<CountTokensResponse> {
72        Err(Error::NotSupported {
73            provider: self.name(),
74            operation: "count_tokens",
75        })
76    }
77}