use core::fmt::Debug;
use std::sync::Arc;
use gcp_vertex_ai_generative_language::google::ai::generativelanguage::v1beta2::content_filter::BlockedReason;
use gcp_vertex_ai_generative_language::google::ai::generativelanguage::v1beta2::{
CountMessageTokensRequest, Example, GenerateMessageRequest, GetModelRequest, Message,
MessagePrompt,
};
use gcp_vertex_ai_generative_language::{Credentials, LanguageClient};
use tokio::sync::Mutex;
use tracing::warn;
use crate::models::{
self, ChatEntryTokenNumber, ChatInput, Error, ModelRef, ModelResponse, Role, SupportedModel,
};
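/// Chat model backed by the Generative Language API's `chat-bison-001` model.
///
/// The gRPC client is shared behind an `Arc<Mutex<_>>` so the model can be cloned cheaply
/// and used from several tasks.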
#[derive(Clone)]
pub struct LanguageModel {
model: SupportedModel,
pub temperature: Option<f32>,
client: Arc<Mutex<LanguageClient>>,
}
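// Manual `Debug` implementation: the wrapped gRPC client handle is omitted from the output.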
impl Debug for LanguageModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LanguageModel")
.field("temperature", &self.temperature)
.field("model", &self.model)
.finish()
}
}
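/// Builds a `chat-bison-001` [`LanguageModel`] authenticated with the given API key and
/// returns it behind a shared [`ModelRef`].
///
/// A minimal usage sketch (the environment variable name is illustrative, not part of this API):
///
/// ```ignore
/// let api_key = std::env::var("GOOGLE_API_KEY").expect("missing API key");
/// let model = build(api_key, Some(0.2)).await?;
/// ```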
pub async fn build(api_key: String, temperature: Option<f32>) -> Result<ModelRef, Error> {
// Fail the builder instead of panicking if the client cannot be constructed
// (e.g. invalid credentials).
let client = LanguageClient::new(Credentials::ApiKey(api_key)).await?;
let model = LanguageModel {
model: SupportedModel::ChatBison001,
temperature,
client: Arc::new(Mutex::new(client)),
};
Ok(Arc::new(Box::new(model)))
}
impl LanguageModel {
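/// Flattens a [`ChatInput`] into the [`MessagePrompt`] shape expected by the discuss
/// service: context entries are joined with newlines, each `(user, bot)` pair becomes an
/// [`Example`], and the chat history becomes the message list.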
fn prepare_input(&self, input: ChatInput) -> Result<MessagePrompt, Error> {
let context = input
.context
.iter()
.map(|c| c.msg.to_string())
.collect::<Vec<String>>()
.join("\n");
let examples = input
.examples
.iter()
.map(|(user, bot)| Example {
input: Some(Message {
author: Role::User.to_string(),
content: user.msg.to_string(),
citation_metadata: None,
}),
output: Some(Message {
author: Role::Assistant.to_string(),
content: bot.msg.to_string(),
citation_metadata: None,
}),
})
.collect();
let messages = input
.chat
.iter()
.map(|m| Message {
author: m.role.to_string(),
content: m.msg.to_string(),
citation_metadata: None,
})
.collect();
let message_prompt = MessagePrompt {
context,
examples,
messages,
};
Ok(message_prompt)
}
}
#[async_trait::async_trait]
impl ChatEntryTokenNumber for LanguageModel {
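/// Counts the tokens the prompt would consume by sending a `CountMessageTokens` request
/// to the discuss service for the configured model.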
async fn num_tokens(&self, input: ChatInput) -> usize {
let prompt = self.prepare_input(input).expect("failed to prepare message prompt");
let req = CountMessageTokensRequest {
model: format!("models/{}", self.model),
prompt: Some(prompt),
};
let mut client = self.client.lock().await;
let resp = client
.discuss_service
.count_message_tokens(req)
.await
.expect("CountMessageTokens request failed");
resp.get_ref().token_count as usize
}
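/// Looks up the model's `input_token_limit` via the model service (`GetModel`).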
async fn context_size(&self) -> usize {
let mut client = self.client.lock().await;
let req = GetModelRequest {
name: format!("models/{}", self.model),
};
client
.model_service
.get_model(req)
.await
.expect("GetModel request failed")
.get_ref()
.input_token_limit as usize
}
}
#[async_trait::async_trait]
impl models::Model for LanguageModel {
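/// Sends the prompt to the discuss service (`GenerateMessage`) and returns the first
/// candidate. `_max_tokens` is ignored because the request type used here exposes no
/// output token limit. An empty candidate list is mapped to [`Error::Filtered`] when
/// content filters fired, otherwise to [`Error::NoResponseFromModel`].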
async fn query(
&self,
input: ChatInput,
_max_tokens: Option<usize>,
) -> Result<ModelResponse, Error> {
let prompt = self.prepare_input(input)?;
let req = GenerateMessageRequest {
model: format!("models/{}", self.model),
prompt: Some(prompt),
temperature: self.temperature,
candidate_count: Some(1),
top_p: None,
top_k: None,
};
let mut client = self.client.lock().await;
let resp = client
.discuss_service
.generate_message(req)
.await
.map_err(gcp_vertex_ai_generative_language::Error::from)?;
let resp = resp.get_ref();
if resp.candidates.is_empty() {
if !resp.filters.is_empty() {
for f in &resp.filters {
let reason = BlockedReason::try_from(f.reason).unwrap_or(BlockedReason::Unspecified);
match f.message.as_ref() {
Some(message) => warn!("Filter: {:?} - {}", reason, message),
None => warn!("Filter: {:?}", reason),
}
}
return Err(Error::Filtered);
}
return Err(Error::NoResponseFromModel);
}
Ok(ModelResponse {
msg: resp.candidates[0].content.clone(),
usage: None,
finish_reason: None,
})
}
}