use reqwest::{IntoUrl, Url};
use crate::{
Error, Result,
core::{Message, language_model::LanguageModelOptions, messages::TaggedMessage},
};
pub fn step_count_is(step: usize) -> impl Fn(&LanguageModelOptions) -> bool {
move |options| options.steps().len() > step
}
pub(crate) fn resolve_message(
options: &LanguageModelOptions,
prompt: &Option<String>,
) -> (String, Vec<TaggedMessage>) {
let messages = if options.messages.is_empty() {
let mut msgs = Vec::new();
if let Some(ref system) = options.system
&& !system.is_empty()
{
msgs.push(TaggedMessage::initial_step_msg(Message::System(
system.clone().into(),
)));
}
msgs.push(TaggedMessage::initial_step_msg(Message::User(
prompt.to_owned().unwrap_or_default().into(),
)));
msgs
} else {
options.messages.to_vec()
};
let system = options.system.to_owned().unwrap_or_else(|| {
messages
.iter()
.find_map(|m| match m.message {
Message::System(ref s) => Some(s.content.to_string()),
_ => None,
})
.unwrap_or_default()
});
(system, messages)
}
pub(crate) fn sum_options(a: Option<usize>, b: Option<usize>) -> Option<usize> {
match (a, b) {
(Some(x), Some(y)) => Some(x + y),
_ => a.or(b),
}
}
#[allow(dead_code)]
pub(crate) fn validate_base_url(s: &str) -> crate::error::Result<String> {
use reqwest::Url;
let url = s
.parse::<Url>()
.map_err(|_| Error::InvalidInput("Invalid base URL".into()))?;
if !matches!(url.scheme(), "http" | "https") {
return Err(Error::InvalidInput(
"Base URL must start with http:// or https://".into(),
));
}
if url.host_str().is_none() {
return Err(Error::InvalidInput("Base URL must include a host".into()));
}
Ok(url.to_string())
}
pub(crate) fn join_url(base_url: impl IntoUrl, path: &str) -> Result<Url> {
let base_url = base_url
.into_url()
.map_err(|_| Error::InvalidInput("Invalid base URL".into()))?;
let base_str = base_url.as_str().trim_end_matches('/');
let path_str = path.trim_start_matches('/');
let full_url = format!("{base_str}/{path_str}");
Url::parse(&full_url)
.map_err(|_| Error::InvalidInput("Failed to join base URL and path".into()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sum_options_both_some() {
assert_eq!(sum_options(Some(1), Some(2)), Some(3));
assert_eq!(sum_options(Some(0), Some(0)), Some(0));
assert_eq!(sum_options(Some(10), Some(20)), Some(30));
}
#[test]
fn test_sum_options_first_some_second_none() {
assert_eq!(sum_options(Some(5), None), Some(5));
assert_eq!(sum_options(Some(0), None), Some(0));
}
#[test]
fn test_sum_options_first_none_second_some() {
assert_eq!(sum_options(None, Some(7)), Some(7));
assert_eq!(sum_options(None, Some(0)), Some(0));
}
#[test]
fn test_sum_options_both_none() {
assert_eq!(sum_options(None, None), None);
}
#[test]
fn test_join_url() {
let url = join_url("https://api.example.com/v1", "chat/completions").unwrap();
assert_eq!(url.as_str(), "https://api.example.com/v1/chat/completions");
let url = join_url("https://api.example.com/v1/", "chat/completions").unwrap();
assert_eq!(url.as_str(), "https://api.example.com/v1/chat/completions");
let url = join_url("https://api.example.com/v1", "/chat/completions").unwrap();
assert_eq!(url.as_str(), "https://api.example.com/v1/chat/completions");
let url = join_url("https://api.example.com/v1/", "/chat/completions").unwrap();
assert_eq!(url.as_str(), "https://api.example.com/v1/chat/completions");
}
}