rig/providers/together/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, Nothing, Provider, ProviderBuilder, ProviderClient,
4    },
5    http_client,
6};
7
8// ================================================================
9// Together AI Client
10// ================================================================
11const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";
12
13#[derive(Debug, Default, Clone, Copy)]
14pub struct TogetherExt;
15#[derive(Debug, Default, Clone, Copy)]
16pub struct TogetherExtBuilder;
17
18type TogetherApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<TogetherExt, H>;
21pub type ClientBuilder<H = reqwest::Client> =
22    client::ClientBuilder<TogetherExtBuilder, TogetherApiKey, H>;
23
24impl Provider for TogetherExt {
25    type Builder = TogetherExtBuilder;
26
27    const VERIFY_PATH: &'static str = "/models";
28
29    fn build<H>(
30        _: &client::ClientBuilder<Self::Builder, TogetherApiKey, H>,
31    ) -> http_client::Result<Self> {
32        Ok(Self)
33    }
34}
35
36impl<H> Capabilities<H> for TogetherExt {
37    type Completion = Capable<super::CompletionModel<H>>;
38    type Embeddings = Capable<super::EmbeddingModel<H>>;
39
40    type Transcription = Nothing;
41    #[cfg(feature = "image")]
42    type ImageGeneration = Nothing;
43    #[cfg(feature = "audio")]
44    type AudioGeneration = Nothing;
45}
46
47impl ProviderBuilder for TogetherExtBuilder {
48    type Output = TogetherExt;
49    type ApiKey = TogetherApiKey;
50
51    const BASE_URL: &'static str = TOGETHER_AI_BASE_URL;
52}
53
54impl ProviderClient for Client {
55    type Input = String;
56
57    /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable.
58    /// Panics if the environment variable is not set.
59    fn from_env() -> Self {
60        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
61        Self::new(&api_key).unwrap()
62    }
63
64    fn from_val(input: Self::Input) -> Self {
65        Self::new(&input).unwrap()
66    }
67}
68
69pub mod together_ai_api_types {
70    use serde::Deserialize;
71
72    impl ApiErrorResponse {
73        pub fn message(&self) -> String {
74            format!("Code `{}`: {}", self.code, self.error)
75        }
76    }
77
78    #[derive(Debug, Deserialize)]
79    pub struct ApiErrorResponse {
80        pub error: String,
81        pub code: String,
82    }
83
84    #[derive(Debug, Deserialize)]
85    #[serde(untagged)]
86    pub enum ApiResponse<T> {
87        Ok(T),
88        Error(ApiErrorResponse),
89    }
90}