Skip to main content

rig_dyn/
traits.rs

1use async_trait::async_trait;
2use rig::{
3    OneOrMany,
4    completion::{CompletionError, CompletionRequest},
5    embeddings::{self, Embedding, EmbeddingError},
6    message::AssistantContent,
7};
8
9#[async_trait]
10pub trait EmbeddingModel: Send + Sync {
11    async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError>;
12    async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError>;
13    fn ndims(&self) -> usize;
14}
15
16#[async_trait]
17impl<T> EmbeddingModel for T
18where
19    T: embeddings::EmbeddingModel + Send + Sync,
20{
21    async fn embed_text(&self, input: &str) -> Result<Embedding, EmbeddingError> {
22        embeddings::EmbeddingModel::embed_text(self, input).await
23    }
24
25    async fn embed_texts(&self, input: Vec<String>) -> Result<Vec<Embedding>, EmbeddingError> {
26        embeddings::EmbeddingModel::embed_texts(self, input).await
27    }
28
29    fn ndims(&self) -> usize {
30        embeddings::EmbeddingModel::ndims(self)
31    }
32}
33
34#[async_trait]
35pub trait CompletionModel: Send + Sync {
36    async fn completion(
37        &self,
38        completion: CompletionRequest,
39    ) -> Result<OneOrMany<AssistantContent>, CompletionError>;
40}
41
42#[async_trait]
43impl<M> CompletionModel for M
44where
45    M: rig::completion::CompletionModel + Send + Sync,
46{
47    async fn completion(
48        &self,
49        request: CompletionRequest,
50    ) -> Result<OneOrMany<AssistantContent>, CompletionError> {
51        Ok(self.completion(request).await?.choice)
52    }
53}