// golem_ai_embed_cohere/
lib.rs1use client::EmbeddingsApi;
2use conversions::create_embed_request;
3use golem_ai_embed::{
4 durability::{DurableEmbed, ExtendedEmbeddingProvider},
5 model::{
6 Config, ContentPart, EmbeddingResponse as GolemEmbeddingResponse, Error, RerankResponse,
7 },
8 EmbeddingProvider, LOGGING_STATE,
9};
10
11use crate::conversions::{
12 create_rerank_request, process_embedding_response, process_rerank_response,
13};
14
15mod client;
16pub mod config;
17mod conversions;
18
19pub use config::CohereConfig;
20#[cfg(feature = "golem")]
21pub use config::CohereHostConfig;
22
23pub struct Cohere;
24
25impl Cohere {
26 fn embeddings(
27 client: EmbeddingsApi,
28 inputs: Vec<ContentPart>,
29 config: Config,
30 ) -> Result<GolemEmbeddingResponse, Error> {
31 let request = create_embed_request(inputs, config.clone());
32 match request {
33 Ok(request) => match client.generate_embeding(request) {
34 Ok(response) => process_embedding_response(response, config),
35 Err(err) => Err(err),
36 },
37 Err(err) => Err(err),
38 }
39 }
40
41 fn rerank(
42 client: EmbeddingsApi,
43 query: String,
44 documents: Vec<String>,
45 config: Config,
46 ) -> Result<RerankResponse, Error> {
47 let request = create_rerank_request(query, documents, config.clone());
48 match request {
49 Ok(request) => match client.rerank(request) {
50 Ok(response) => process_rerank_response(response, config),
51 Err(err) => Err(err),
52 },
53 Err(err) => Err(err),
54 }
55 }
56}
57
58impl EmbeddingProvider for Cohere {
59 type ProviderConfig = CohereConfig;
60
61 fn generate(
62 provider_config: Self::ProviderConfig,
63 inputs: Vec<ContentPart>,
64 config: Config,
65 ) -> Result<GolemEmbeddingResponse, Error> {
66 LOGGING_STATE.with_borrow_mut(|state| state.init());
67 let client = EmbeddingsApi::new(&provider_config);
68 Self::embeddings(client, inputs, config)
69 }
70
71 fn rerank(
72 provider_config: Self::ProviderConfig,
73 query: String,
74 documents: Vec<String>,
75 config: Config,
76 ) -> Result<RerankResponse, Error> {
77 LOGGING_STATE.with_borrow_mut(|state| state.init());
78 let client = EmbeddingsApi::new(&provider_config);
79 Self::rerank(client, query, documents, config)
80 }
81}
82
83impl ExtendedEmbeddingProvider for Cohere {}
84
85pub type DurableCohere = DurableEmbed<Cohere>;