1use async_trait::async_trait;
2use rig::{
3 OneOrMany,
4 completion::{CompletionError, CompletionRequest},
5 embeddings::{self, Embedding, EmbeddingError},
6 message::AssistantContent,
7};
8
9#[async_trait]
10pub trait EmbeddingModel: Send + Sync {
11 async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError>;
12 async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError>;
13 fn ndims(&self) -> usize;
14}
15
16#[async_trait]
17impl<T> EmbeddingModel for T
18where
19 T: embeddings::EmbeddingModel + Send + Sync,
20{
21 async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
22 embeddings::EmbeddingModel::embed_text(self, input).await
23 }
24
25 async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError> {
26 embeddings::EmbeddingModel::embed_texts(self, input).await
27 }
28
29 fn ndims(&self) -> usize {
30 embeddings::EmbeddingModel::ndims(self)
31 }
32}
33
34#[async_trait]
35pub trait CompletionModel: Send + Sync {
36 async fn completion(
37 &self,
38 completion: CompletionRequest,
39 ) -> Result<OneOrMany<AssistantContent>, CompletionError>;
40}
41
42#[async_trait]
43impl<M> CompletionModel for M
44where
45 M: rig::completion::CompletionModel + Send + Sync,
46{
47 async fn completion(
48 &self,
49 request: CompletionRequest,
50 ) -> Result<OneOrMany<AssistantContent>, CompletionError> {
51 Ok(self.completion(request).await?.choice)
52 }
53}