1use async_trait::async_trait;
2use rig::{
3 client::FinalCompletionResponse,
4 completion::{self, CompletionError, CompletionRequest, CompletionResponse, GetTokenUsage},
5 embeddings::{self, Embedding, EmbeddingError},
6 streaming::StreamingCompletionResponse,
7};
8use std::sync::Arc;
9use embeddings::EmbeddingModel;
10use rig::wasm_compat::WasmCompatSend;
11
12#[async_trait]
13pub trait DynEmbeddingModel: Send + Sync {
14 async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError>;
15 async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError>;
16 fn ndims(&self) -> usize;
17}
18
19#[derive(Clone)]
20#[allow(dead_code)]
21pub struct RigEmbeddingModelAdapter {
22 inner: Arc<dyn DynEmbeddingModel>,
23}
24
25impl RigEmbeddingModelAdapter {
26 #[allow(dead_code)]
27 pub fn new(inner: Arc<dyn DynEmbeddingModel>) -> Self {
28 Self { inner }
29 }
30}
31
32impl From<Box<dyn DynEmbeddingModel>> for RigEmbeddingModelAdapter {
33 fn from(value: Box<dyn DynEmbeddingModel>) -> Self {
34 Self {
35 inner: Arc::from(value),
36 }
37 }
38}
39
40impl From<Arc<dyn DynEmbeddingModel>> for RigEmbeddingModelAdapter {
41 fn from(value: Arc<dyn DynEmbeddingModel>) -> Self {
42 Self { inner: value }
43 }
44}
45
46impl EmbeddingModel for RigEmbeddingModelAdapter {
47 const MAX_DOCUMENTS: usize = 1000;
48 type Client = ();
49
50
51 fn make(_client: &Self::Client, _model: impl Into<String>, _dims: Option<usize>) -> Self {
52 panic!("make() is not supported by rig_dyn::EmbeddingModel adapter");
53 }
54
55 fn ndims(&self) -> usize {
56 self.inner.ndims()
57 }
58
59 async fn embed_texts(&self, texts: impl IntoIterator<Item = String> + WasmCompatSend,) -> Result<Vec<Embedding>, EmbeddingError> {
60 let texts_vec: Vec<String> = texts.into_iter().collect();
61 self.inner.embed_texts(texts_vec).await
62 }
63
64 async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
65 self.inner.embed_text(input).await
66 }
67}
68
69#[async_trait]
70impl<T> DynEmbeddingModel for T
71where
72 T: EmbeddingModel + Send + Sync,
73{
74 async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
75 EmbeddingModel::embed_text(self, input).await
76 }
77
78 async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError> {
79 EmbeddingModel::embed_texts(self, input).await
80 }
81
82 fn ndims(&self) -> usize {
83 EmbeddingModel::ndims(self)
84 }
85}
86
87#[async_trait]
88pub trait CompletionModel: Send + Sync {
89 async fn completion(
90 &self,
91 request: CompletionRequest,
92 ) -> Result<CompletionResponse<()>, CompletionError>;
93}
94
95#[derive(Clone)]
96pub struct RigCompletionModelAdapter {
97 inner: Arc<dyn CompletionModel>,
98}
99
100impl RigCompletionModelAdapter {
101 pub fn new(inner: Arc<dyn CompletionModel>) -> Self {
102 Self { inner }
103 }
104}
105
106impl From<Box<dyn CompletionModel>> for RigCompletionModelAdapter {
107 fn from(value: Box<dyn CompletionModel>) -> Self {
108 Self {
109 inner: Arc::from(value),
110 }
111 }
112}
113
114impl From<Arc<dyn CompletionModel>> for RigCompletionModelAdapter {
115 fn from(value: Arc<dyn CompletionModel>) -> Self {
116 Self { inner: value }
117 }
118}
119
120impl completion::CompletionModel for RigCompletionModelAdapter {
121 type Response = ();
122 type StreamingResponse = FinalCompletionResponse;
123 type Client = Arc<dyn CompletionModel>;
124
125 fn make(client: &Self::Client, _model: impl Into<String>) -> Self {
126 Self {
127 inner: client.clone(),
128 }
129 }
130
131 fn completion(
132 &self,
133 request: CompletionRequest,
134 ) -> impl std::future::Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>>
135 + rig::wasm_compat::WasmCompatSend {
136 let model = self.inner.clone();
137 async move { model.completion(request).await }
138 }
139
140 fn stream(
141 &self,
142 _request: CompletionRequest,
143 ) -> impl std::future::Future<
144 Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
145 > + rig::wasm_compat::WasmCompatSend {
146 async {
147 Err(CompletionError::ResponseError(
148 "Streaming is not supported by rig_dyn::CompletionModel adapter".to_string(),
149 ))
150 }
151 }
152}
153
154#[async_trait]
155impl<M> CompletionModel for M
156where
157 M: completion::CompletionModel + Send + Sync,
158 M::StreamingResponse: Clone + Unpin + GetTokenUsage + 'static,
159{
160 async fn completion(
161 &self,
162 request: CompletionRequest,
163 ) -> Result<CompletionResponse<()>, CompletionError> {
164 self.completion(request).await.map(|response| CompletionResponse {
165 choice: response.choice,
166 usage: response.usage,
167 raw_response: (),
168 message_id: response.message_id,
169 })
170 }
171}