rig/providers/together/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, Nothing, Provider, ProviderBuilder, ProviderClient,
4 },
5 http_client,
6};
/// Root URL of the Together AI HTTP API.
///
/// Installed as `ProviderBuilder::BASE_URL` for `TogetherExtBuilder`. Note that it carries
/// no API-version segment, so endpoint paths must include any `/v1` prefix themselves.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";

/// Zero-sized provider extension that plugs Together AI into the generic `client::Client`
/// (see the `Client` alias in this module).
#[derive(Clone, Copy, Debug, Default)]
pub struct TogetherExt;

/// Zero-sized builder marker; `client::ClientBuilder` uses it to produce a [`TogetherExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct TogetherExtBuilder;

18type TogetherApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<TogetherExt, H>;
21pub type ClientBuilder<H = reqwest::Client> =
22 client::ClientBuilder<TogetherExtBuilder, TogetherApiKey, H>;
23
24impl Provider for TogetherExt {
25 type Builder = TogetherExtBuilder;
26
27 const VERIFY_PATH: &'static str = "/models";
28}
29
30impl<H> Capabilities<H> for TogetherExt {
31 type Completion = Capable<super::CompletionModel<H>>;
32 type Embeddings = Capable<super::EmbeddingModel<H>>;
33
34 type Transcription = Nothing;
35 type ModelListing = Nothing;
36 #[cfg(feature = "image")]
37 type ImageGeneration = Nothing;
38 #[cfg(feature = "audio")]
39 type AudioGeneration = Nothing;
40}
41
42impl ProviderBuilder for TogetherExtBuilder {
43 type Extension<H>
44 = TogetherExt
45 where
46 H: http_client::HttpClientExt;
47 type ApiKey = TogetherApiKey;
48
49 const BASE_URL: &'static str = TOGETHER_AI_BASE_URL;
50
51 fn build<H>(
52 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
53 ) -> http_client::Result<Self::Extension<H>>
54 where
55 H: http_client::HttpClientExt,
56 {
57 Ok(TogetherExt)
58 }
59}
60
61impl ProviderClient for Client {
62 type Input = String;
63
64 fn from_env() -> Self {
67 let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
68 Self::new(&api_key).unwrap()
69 }
70
71 fn from_val(input: Self::Input) -> Self {
72 Self::new(&input).unwrap()
73 }
74}
75
76pub mod together_ai_api_types {
77 use serde::Deserialize;
78
79 impl ApiErrorResponse {
80 pub fn message(&self) -> String {
81 format!("Code `{}`: {}", self.code, self.error)
82 }
83 }
84
85 #[derive(Debug, Deserialize)]
86 pub struct ApiErrorResponse {
87 pub error: String,
88 pub code: String,
89 }
90
91 #[derive(Debug, Deserialize)]
92 #[serde(untagged)]
93 pub enum ApiResponse<T> {
94 Ok(T),
95 Error(ApiErrorResponse),
96 }
97}

#[cfg(test)]
mod tests {
    use crate::providers::together::Client;

    /// Both construction routes, `Client::new` and `Client::builder()`, must accept a
    /// placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}