use reqwest::header::HeaderValue;
use serde::Serialize;
use crate::{
config::{RecordReplayMode, REPLAY_LOG},
KagiConfig, KagiError, KagiResult,
};
/// HTTP client for the Kagi API.
///
/// Pairs the user's [`KagiConfig`] with a middleware-capable `reqwest` client. When built via
/// `KagiClient::with_replay` and a log directory is configured, the inner client carries an
/// `rvcr` record/replay middleware; `KagiClient::new` attaches no middleware.
pub struct KagiClient {
/// Configuration used by this client: `query` reads `api_base` and `api_key`,
/// `with_replay` reads `logdir`.
pub config: KagiConfig,
/// Underlying HTTP client; every request is sent through whatever middleware was attached
/// when the client was constructed.
pub inner_client: reqwest_middleware::ClientWithMiddleware,
}
impl KagiClient {
pub fn new(config: KagiConfig) -> Self {
let reqwest_client = reqwest::Client::new();
let client_builder = reqwest_middleware::ClientBuilder::new(reqwest_client);
let inner_client = client_builder.build();
Self {
config,
inner_client,
}
}
pub fn with_replay(
config: KagiConfig,
reqwest_client: reqwest::Client,
record_replay_mode: RecordReplayMode,
) -> KagiResult<Self> {
let client_builder = reqwest_middleware::ClientBuilder::new(reqwest_client);
let Some(vcr_logfile) = config.logdir.as_ref().map(|path| path.join(REPLAY_LOG)) else {
let inner_client = client_builder.build();
let client = KagiClient {
config,
inner_client,
};
return Ok(client);
};
let middleware = rvcr::VCRMiddleware::try_from(vcr_logfile.to_path_buf())
.map_err(KagiError::VCRMiddlewareError)?
.with_mode(rvcr::VCRMode::from(record_replay_mode));
let inner_client = client_builder.with(middleware).build();
let client = KagiClient {
config,
inner_client,
};
Ok(client)
}
pub async fn query<Q>(&self, endpoint: &str, request_body: Q) -> KagiResult<reqwest::Response>
where
Q: Serialize,
{
let request_url = reqwest::Url::parse(&format!("{}{}", self.config.api_base, endpoint))
.map_err(KagiError::UrlError)?;
let content_type: Result<HeaderValue, _> = "application/json".parse();
let authorization: Result<HeaderValue, _> = format!("Bot {}", self.config.api_key).parse();
let request = self
.inner_client
.post(request_url)
.header("Content-Type", content_type.unwrap())
.header("Authorization", authorization.unwrap())
.json(&request_body);
let response = request
.send()
.await
.map_err(KagiError::ReqwestMiddlewareError)?;
if response.status() != reqwest::StatusCode::OK {
return Err(KagiError::StatusCodeError(response.status()));
}
Ok(response)
}
}