Skip to main content

rig_dyn/
traits.rs

1use async_trait::async_trait;
2use rig::{
3    client::FinalCompletionResponse,
4    completion::{self, CompletionError, CompletionRequest, CompletionResponse, GetTokenUsage},
5    embeddings::{self, Embedding, EmbeddingError},
6    streaming::StreamingCompletionResponse,
7};
8use std::sync::Arc;
9use embeddings::EmbeddingModel;
10use rig::wasm_compat::WasmCompatSend;
11
12#[async_trait]
13pub trait DynEmbeddingModel: Send + Sync {
14    async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError>;
15    async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError>;
16    fn ndims(&self) -> usize;
17}
18
19#[derive(Clone)]
20#[allow(dead_code)]
21pub struct RigEmbeddingModelAdapter {
22    inner: Arc<dyn DynEmbeddingModel>,
23}
24
25impl RigEmbeddingModelAdapter {
26    #[allow(dead_code)]
27    pub fn new(inner: Arc<dyn DynEmbeddingModel>) -> Self {
28        Self { inner }
29    }
30}
31
32impl From<Box<dyn DynEmbeddingModel>> for RigEmbeddingModelAdapter {
33    fn from(value: Box<dyn DynEmbeddingModel>) -> Self {
34        Self {
35            inner: Arc::from(value),
36        }
37    }
38}
39
40impl From<Arc<dyn DynEmbeddingModel>> for RigEmbeddingModelAdapter {
41    fn from(value: Arc<dyn DynEmbeddingModel>) -> Self {
42        Self { inner: value }
43    }
44}
45
46impl EmbeddingModel for RigEmbeddingModelAdapter {
47    const MAX_DOCUMENTS: usize = 1000;
48    type Client = ();
49
50
51    fn make(_client: &Self::Client, _model: impl Into<String>, _dims: Option<usize>) -> Self {
52        panic!("make() is not supported by rig_dyn::EmbeddingModel adapter");
53    }
54
55    fn ndims(&self) -> usize {
56        self.inner.ndims()
57    }
58
59    async fn embed_texts(&self, texts: impl IntoIterator<Item = String> + WasmCompatSend,) -> Result<Vec<Embedding>, EmbeddingError> {
60        let texts_vec: Vec<String> = texts.into_iter().collect();
61        self.inner.embed_texts(texts_vec).await
62    }
63
64    async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
65        self.inner.embed_text(input).await
66    }
67}
68
69#[async_trait]
70impl<T> DynEmbeddingModel for T
71where
72    T: EmbeddingModel + Send + Sync,
73{
74    async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
75        EmbeddingModel::embed_text(self, input).await
76    }
77
78    async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError> {
79        EmbeddingModel::embed_texts(self, input).await
80    }
81
82    fn ndims(&self) -> usize {
83        EmbeddingModel::ndims(self)
84    }
85}
86
87#[async_trait]
88pub trait CompletionModel: Send + Sync {
89    async fn completion(
90        &self,
91        request: CompletionRequest,
92    ) -> Result<CompletionResponse<()>, CompletionError>;
93}
94
95#[derive(Clone)]
96pub struct RigCompletionModelAdapter {
97    inner: Arc<dyn CompletionModel>,
98}
99
100impl RigCompletionModelAdapter {
101    pub fn new(inner: Arc<dyn CompletionModel>) -> Self {
102        Self { inner }
103    }
104}
105
106impl From<Box<dyn CompletionModel>> for RigCompletionModelAdapter {
107    fn from(value: Box<dyn CompletionModel>) -> Self {
108        Self {
109            inner: Arc::from(value),
110        }
111    }
112}
113
114impl From<Arc<dyn CompletionModel>> for RigCompletionModelAdapter {
115    fn from(value: Arc<dyn CompletionModel>) -> Self {
116        Self { inner: value }
117    }
118}
119
120impl completion::CompletionModel for RigCompletionModelAdapter {
121    type Response = ();
122    type StreamingResponse = FinalCompletionResponse;
123    type Client = Arc<dyn CompletionModel>;
124
125    fn make(client: &Self::Client, _model: impl Into<String>) -> Self {
126        Self {
127            inner: client.clone(),
128        }
129    }
130
131    fn completion(
132        &self,
133        request: CompletionRequest,
134    ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
135           + rig::wasm_compat::WasmCompatSend {
136        let model = self.inner.clone();
137        async move { model.completion(request).await }
138    }
139
140    fn stream(
141        &self,
142        _request: CompletionRequest,
143    ) -> impl std::future::Future<
144        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
145    > + rig::wasm_compat::WasmCompatSend {
146        async {
147            Err(CompletionError::ResponseError(
148                "Streaming is not supported by rig_dyn::CompletionModel adapter".to_string(),
149            ))
150        }
151    }
152}
153
154#[async_trait]
155impl<M> CompletionModel for M
156where
157    M: completion::CompletionModel + Send + Sync,
158    M::StreamingResponse: Clone + Unpin + GetTokenUsage + 'static,
159{
160    async fn completion(
161        &self,
162        request: CompletionRequest,
163    ) -> Result<CompletionResponse<()>, CompletionError> {
164        self.completion(request).await.map(|response| CompletionResponse {
165            choice: response.choice,
166            usage: response.usage,
167            raw_response: (),
168            message_id: response.message_id,
169        })
170    }
171}