use crate::adapter::{AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptions, ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::embed::{EmbedOptions, EmbedOptionsSet, EmbedRequest, EmbedResponse};
use crate::resolver::AuthData;
use crate::{Client, Error, ModelIden, Result, ServiceTarget};
impl Client {
pub async fn all_model_names(&self, adapter_kind: AdapterKind) -> Result<Vec<String>> {
let models = AdapterDispatcher::all_model_names(adapter_kind).await?;
Ok(models)
}
pub fn default_model(&self, model_name: &str) -> Result<ModelIden> {
let adapter_kind = AdapterKind::from_model(model_name)?;
let model_iden = ModelIden::new(adapter_kind, model_name);
Ok(model_iden)
}
#[deprecated(note = "use `client.resolve_service_target(model_name)`")]
pub async fn resolve_model_iden(&self, model_name: &str) -> Result<ModelIden> {
let model = self.default_model(model_name)?;
let target = self.config().resolve_service_target(model).await?;
Ok(target.model)
}
pub async fn resolve_service_target(&self, model_name: &str) -> Result<ServiceTarget> {
let model = self.default_model(model_name)?;
self.config().resolve_service_target(model).await
}
pub async fn exec_chat(
&self,
model: &str,
chat_req: ChatRequest,
options: Option<&ChatOptions>,
) -> Result<ChatResponse> {
let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model).await?;
let model = target.model.clone();
let auth_data = target.auth.clone();
let WebRequestData {
mut url,
mut headers,
payload,
} = AdapterDispatcher::to_web_request_data(target, ServiceType::Chat, chat_req, options_set.clone())?;
if let Some(extra_headers) = options.and_then(|o| o.extra_headers.as_ref()) {
headers.merge_with(&extra_headers);
}
if let AuthData::RequestOverride {
url: override_url,
headers: override_headers,
} = auth_data
{
url = override_url;
headers = override_headers;
};
let web_res = self
.web_client()
.do_post(&url, &headers, &payload)
.await
.map_err(|webc_error| Error::WebModelCall {
model_iden: model.clone(),
webc_error,
})?;
let captured_raw_body = options_set.capture_raw_body().unwrap_or_default().then(|| web_res.body.clone());
match AdapterDispatcher::to_chat_response(model.clone(), web_res, options_set) {
Ok(mut chat_res) => {
chat_res.captured_raw_body = captured_raw_body;
Ok(chat_res)
}
Err(err) => {
let response_body = captured_raw_body.unwrap_or_else(|| {
"Raw response not captured. Use the ChatOptions.capturre_raw_body flag to see raw response in this error".into()
});
let err = Error::ChatResponseGeneration {
model_iden: model,
request_payload: Box::new(payload),
response_body: Box::new(response_body),
cause: err.to_string(),
};
Err(err)
}
}
}
pub async fn exec_chat_stream(
&self,
model: &str,
chat_req: ChatRequest, options: Option<&ChatOptions>,
) -> Result<ChatStreamResponse> {
let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model).await?;
let model = target.model.clone();
let auth_data = target.auth.clone();
let WebRequestData {
mut url,
mut headers,
payload,
} = AdapterDispatcher::to_web_request_data(target, ServiceType::ChatStream, chat_req, options_set.clone())?;
if let Some(extra_headers) = options.and_then(|o| o.extra_headers.as_ref()) {
headers.merge_with(&extra_headers);
}
if let AuthData::RequestOverride {
url: override_url,
headers: override_headers,
} = auth_data
{
url = override_url;
headers = override_headers;
};
let reqwest_builder = self
.web_client()
.new_req_builder(&url, &headers, &payload)
.map_err(|webc_error| Error::WebModelCall {
model_iden: model.clone(),
webc_error,
})?;
let res = AdapterDispatcher::to_chat_stream(model, reqwest_builder, options_set)?;
Ok(res)
}
pub async fn embed(
&self,
model: &str,
input: impl Into<String>,
options: Option<&EmbedOptions>,
) -> Result<EmbedResponse> {
let embed_req = EmbedRequest::new(input);
self.exec_embed(model, embed_req, options).await
}
pub async fn embed_batch(
&self,
model: &str,
inputs: Vec<String>,
options: Option<&EmbedOptions>,
) -> Result<EmbedResponse> {
let embed_req = EmbedRequest::new_batch(inputs);
self.exec_embed(model, embed_req, options).await
}
pub async fn exec_embed(
&self,
model: &str,
embed_req: EmbedRequest,
options: Option<&EmbedOptions>,
) -> Result<EmbedResponse> {
let options_set = EmbedOptionsSet::new()
.with_request_options(options)
.with_client_options(self.config().embed_options());
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model).await?;
let model = target.model.clone();
let WebRequestData { headers, payload, url } =
AdapterDispatcher::to_embed_request_data(target, embed_req, options_set.clone())?;
let web_res = self
.web_client()
.do_post(&url, &headers, &payload)
.await
.map_err(|webc_error| Error::WebModelCall {
model_iden: model.clone(),
webc_error,
})?;
let res = AdapterDispatcher::to_embed_response(model, web_res, options_set)?;
Ok(res)
}
}