Skip to main content

golem_ai_embed_cohere/
lib.rs

1use client::EmbeddingsApi;
2use conversions::create_embed_request;
3use golem_ai_embed::{
4    durability::{DurableEmbed, ExtendedEmbeddingProvider},
5    model::{
6        Config, ContentPart, EmbeddingResponse as GolemEmbeddingResponse, Error, RerankResponse,
7    },
8    EmbeddingProvider, LOGGING_STATE,
9};
10
11use crate::conversions::{
12    create_rerank_request, process_embedding_response, process_rerank_response,
13};
14
15mod client;
16pub mod config;
17mod conversions;
18
19pub use config::CohereConfig;
20#[cfg(feature = "golem")]
21pub use config::CohereHostConfig;
22
/// Cohere-backed embedding provider.
///
/// A stateless unit type: the provider configuration is passed into each
/// `EmbeddingProvider` call rather than being stored on the struct.
pub struct Cohere;
24
25impl Cohere {
26    fn embeddings(
27        client: EmbeddingsApi,
28        inputs: Vec<ContentPart>,
29        config: Config,
30    ) -> Result<GolemEmbeddingResponse, Error> {
31        let request = create_embed_request(inputs, config.clone());
32        match request {
33            Ok(request) => match client.generate_embeding(request) {
34                Ok(response) => process_embedding_response(response, config),
35                Err(err) => Err(err),
36            },
37            Err(err) => Err(err),
38        }
39    }
40
41    fn rerank(
42        client: EmbeddingsApi,
43        query: String,
44        documents: Vec<String>,
45        config: Config,
46    ) -> Result<RerankResponse, Error> {
47        let request = create_rerank_request(query, documents, config.clone());
48        match request {
49            Ok(request) => match client.rerank(request) {
50                Ok(response) => process_rerank_response(response, config),
51                Err(err) => Err(err),
52            },
53            Err(err) => Err(err),
54        }
55    }
56}
57
58impl EmbeddingProvider for Cohere {
59    type ProviderConfig = CohereConfig;
60
61    fn generate(
62        provider_config: Self::ProviderConfig,
63        inputs: Vec<ContentPart>,
64        config: Config,
65    ) -> Result<GolemEmbeddingResponse, Error> {
66        LOGGING_STATE.with_borrow_mut(|state| state.init());
67        let client = EmbeddingsApi::new(&provider_config);
68        Self::embeddings(client, inputs, config)
69    }
70
71    fn rerank(
72        provider_config: Self::ProviderConfig,
73        query: String,
74        documents: Vec<String>,
75        config: Config,
76    ) -> Result<RerankResponse, Error> {
77        LOGGING_STATE.with_borrow_mut(|state| state.init());
78        let client = EmbeddingsApi::new(&provider_config);
79        Self::rerank(client, query, documents, config)
80    }
81}
82
83impl ExtendedEmbeddingProvider for Cohere {}
84
/// The `Cohere` provider wrapped in the `DurableEmbed` layer from
/// `golem_ai_embed::durability`.
pub type DurableCohere = DurableEmbed<Cohere>;