rig_gemini_grpc/
client.rs1use rig::prelude::*;
2use std::fmt::Debug;
3use tonic::metadata::MetadataValue;
4use tonic::service::Interceptor;
5use tonic::transport::{Channel, Endpoint};
6use tonic::{Request, Status};
7
8use super::GenerativeServiceClient;
9use crate::completion::CompletionModel;
10use crate::embedding::EmbeddingModel;
11
/// Base URL of the Gemini (Generative Language API) gRPC endpoint the client connects to.
const GEMINI_GRPC_ENDPOINT: &str = "https://generativelanguage.googleapis.com";

/// Identifier sent as the `x-goog-api-client` metadata value on every request.
const RIG_GRPC_CLIENT_IDENTIFIER: &str = "rig-grpc/0.1.0";
19
20#[derive(Clone)]
21pub struct Client {
22 api_key: String,
23 channel: Channel,
24}
25
26impl Debug for Client {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("Client")
29 .field("api_key", &"******")
30 .field("channel", &"Channel")
31 .finish()
32 }
33}
34
35#[derive(Clone)]
37pub struct ApiKeyInterceptor {
38 api_key: MetadataValue<tonic::metadata::Ascii>,
39 client_id: MetadataValue<tonic::metadata::Ascii>,
40}
41
42impl Interceptor for ApiKeyInterceptor {
43 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
44 request
45 .metadata_mut()
46 .insert("x-goog-api-key", self.api_key.clone());
47 request
48 .metadata_mut()
49 .insert("x-goog-api-client", self.client_id.clone());
50 Ok(request)
51 }
52}
53
54impl Client {
55 pub async fn new(
57 api_key: impl Into<String>,
58 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
59 let api_key = api_key.into();
60 let endpoint = Endpoint::from_static(GEMINI_GRPC_ENDPOINT).tls_config(
61 tonic::transport::ClientTlsConfig::new()
62 .with_webpki_roots()
63 .domain_name("generativelanguage.googleapis.com"),
64 )?;
65
66 let channel = endpoint.connect().await?;
67
68 Ok(Self { api_key, channel })
69 }
70
71 pub(crate) fn grpc_client(
73 &self,
74 ) -> Result<
75 GenerativeServiceClient<
76 tonic::service::interceptor::InterceptedService<Channel, ApiKeyInterceptor>,
77 >,
78 Box<dyn std::error::Error + Send + Sync>,
79 > {
80 let api_key = MetadataValue::try_from(&self.api_key)?;
81 let client_id = MetadataValue::try_from(RIG_GRPC_CLIENT_IDENTIFIER)?;
82 let interceptor = ApiKeyInterceptor { api_key, client_id };
83
84 Ok(GenerativeServiceClient::with_interceptor(
85 self.channel.clone(),
86 interceptor,
87 ))
88 }
89}
90
91impl ProviderClient for Client {
92 type Input = String;
93 type Error = Box<dyn std::error::Error + Send + Sync>;
94
95 fn from_env() -> Result<Self, Self::Error> {
97 let api_key = std::env::var("GEMINI_API_KEY")?;
98 tokio::task::block_in_place(|| {
99 tokio::runtime::Handle::current().block_on(Self::new(api_key))
100 })
101 }
102
103 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
104 tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(Self::new(input)))
105 }
106}
107
108impl CompletionClient for Client {
109 type CompletionModel = CompletionModel;
110
111 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
112 CompletionModel::new(self.clone(), model)
113 }
114}
115
116impl EmbeddingsClient for Client {
117 type EmbeddingModel = EmbeddingModel;
118
119 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
120 EmbeddingModel::new(self.clone(), model, None)
121 }
122
123 fn embedding_model_with_ndims(
124 &self,
125 model: impl Into<String>,
126 ndims: usize,
127 ) -> Self::EmbeddingModel {
128 EmbeddingModel::new(self.clone(), model, Some(ndims))
129 }
130}