pub mod auth;
pub use auth::Authentication;
use google::ai::generativelanguage::v1beta2::discuss_service_client::DiscussServiceClient;
use google::ai::generativelanguage::v1beta2::model_service_client::ModelServiceClient;
use google::ai::generativelanguage::v1beta2::text_service_client::TextServiceClient;
use tonic::codegen::http::uri::InvalidUri;
use tonic::transport::{Certificate, Channel, ClientTlsConfig};
/// Errors returned by this crate's client constructors.
///
/// Every variant carries `#[from]`, so the corresponding source error is
/// converted automatically when propagated with `?`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// A `tonic::transport::Error`; in this file it can come from configuring the
/// channel's user agent or TLS settings in `LanguageClient::new`.
#[error("tonic transport error - {0}")]
Tonic(#[from] tonic::transport::Error),
/// The endpoint string could not be parsed as a URI
/// (returned by `Channel::from_shared` in `LanguageClient::new`).
#[error("{0}")]
InvalidUri(#[from] InvalidUri),
/// A `tonic::Status`. The `Display` output contains only the status message
/// (`.0.message()`), not the rest of the status.
#[error("Status: {}", .0.message())]
Status(#[from] tonic::Status),
}
/// Certificate data in PEM form, embedded at compile time from
/// `../certs/roots.pem` (path relative to this source file). It is passed to
/// `Certificate::from_pem` and installed as the TLS CA certificate in
/// `LanguageClient::new`.
const CERTIFICATES: &str = include_str!("../certs/roots.pem");
/// Credentials handed to `Authentication::build` when the service clients are
/// created.
///
/// The type implements [`std::fmt::Debug`] by hand so that it can be embedded
/// in `Debug`-derived types and logged without exposing the API key.
#[derive(Clone)]
pub enum Credentials {
    /// Authenticate with the given API key.
    ApiKey(String),
    /// Supply no credentials.
    None,
}

// `Debug` is written by hand instead of derived: a derived impl would print
// the API key verbatim, which would leak the secret into logs and panic
// messages.
impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ApiKey(_) => f.write_str("ApiKey(<redacted>)"),
            Self::None => f.write_str("None"),
        }
    }
}
/// Bindings for the `google.api` and `google.ai.generativelanguage.v1beta2`
/// packages; the module path matches the dotted names of the included files.
///
/// Each leaf module textually includes a Rust file from the build's `OUT_DIR`
/// (presumably written by the crate's build script — confirm in `build.rs`).
// `missing_docs` is silenced for the whole tree because the module bodies come
// from the included files rather than from this source file.
#[allow(missing_docs)]
pub mod google {
/// Contents of `$OUT_DIR/google.api.rs`.
pub mod api {
include!(concat!(env!("OUT_DIR"), "/google.api.rs"));
}
/// Parent module for `generativelanguage`.
pub mod ai {
/// Parent module for the versioned API modules.
pub mod generativelanguage {
/// Contents of `$OUT_DIR/google.ai.generativelanguage.v1beta2.rs`; the
/// service client types imported at the top of this file live here.
pub mod v1beta2 {
include!(concat!(
env!("OUT_DIR"),
"/google.ai.generativelanguage.v1beta2.rs"
));
}
}
}
}
/// Bundle of the three `v1beta2` service clients (discuss, model and text).
///
/// Every client wraps a `Channel` in an `InterceptedService` that uses the
/// `Authentication` interceptor. The fields are public, so callers invoke the
/// client methods on them directly.
#[derive(Clone)]
pub struct LanguageClient {
/// Client for the discuss service.
pub discuss_service: DiscussServiceClient<
tonic::service::interceptor::InterceptedService<Channel, Authentication>,
>,
/// Client for the model service.
pub model_service: ModelServiceClient<
tonic::service::interceptor::InterceptedService<Channel, Authentication>,
>,
/// Client for the text service.
pub text_service:
TextServiceClient<tonic::service::interceptor::InterceptedService<Channel, Authentication>>,
}
impl LanguageClient {
pub async fn new(credentials: Credentials) -> Result<Self, Error> {
let domain_name = "generativelanguage.googleapis.com".to_string();
let tls_config = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(CERTIFICATES))
.domain_name(&domain_name);
let endpoint = format!("https://{endpoint}", endpoint = domain_name);
let channel = Channel::from_shared(endpoint)?
.user_agent("github.com/ssoudan/gcp-vertex-ai-generative-ai")?
.tls_config(tls_config)?
.connect_lazy();
Self::from_channel(credentials, channel).await
}
pub async fn from_channel(
credentials: Credentials,
channel: Channel,
) -> Result<LanguageClient, Error> {
let discuss_service = {
let auth = Authentication::build(credentials.clone()).await?;
DiscussServiceClient::with_interceptor(channel.clone(), auth)
};
let model_service = {
let auth = Authentication::build(credentials.clone()).await?;
ModelServiceClient::with_interceptor(channel.clone(), auth)
};
let text_service = {
let auth = Authentication::build(credentials).await?;
TextServiceClient::with_interceptor(channel, auth)
};
Ok(Self {
discuss_service,
model_service,
text_service,
})
}
}
// Test module; its body lives in a separate file and it is compiled only for
// test builds.
#[cfg(test)]
mod test;
/// Helpers shared by the crate's tests.
#[cfg(test)]
mod common {
    use crate::{Credentials, LanguageClient};
    use std::env;

    /// Builds a [`LanguageClient`] from the `GOOGLE_API_KEY` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `GOOGLE_API_KEY` cannot be read or if the client cannot be
    /// constructed.
    pub(crate) async fn test_client() -> LanguageClient {
        let credentials = env::var("GOOGLE_API_KEY")
            .map(Credentials::ApiKey)
            .expect("GOOGLE_API_KEY must be set");

        let client = LanguageClient::new(credentials).await;
        client.unwrap()
    }
}