use crate::tools::openapi::{OpenApiOperationTool, OpenApiSpec};
use crate::tools::{BaseTool, BaseToolset};
use std::path::Path;
use std::sync::Arc;
/// A named collection of tools generated from a single OpenAPI document.
///
/// Each tool wraps one (path, HTTP method) operation of the document.
pub struct OpenApiToolSet {
    /// Identifier supplied by the caller when the toolset was created.
    name: String,
    /// Where the spec came from: a file path (lossily stringified) or a URL.
    spec_path: String,
    /// One boxed tool per operation found in the spec.
    tools: Vec<Box<dyn BaseTool>>,
}
impl OpenApiToolSet {
    /// Loads an OpenAPI document from a local file and builds one tool per
    /// operation it declares.
    ///
    /// The file path, lossily converted to a string, is recorded and returned
    /// by `spec_path()`.
    ///
    /// # Errors
    ///
    /// Returns an error if `OpenApiSpec::from_file` fails, or if the HTTP
    /// client cannot be built (e.g. an API-key header name or value is not a
    /// valid HTTP header).
    // The body never awaits; the fn is presumably kept `async` for signature
    // parity with `from_url` (callers `.await` both the same way).
    #[allow(clippy::unused_async)]
    pub async fn from_file(
        name: String,
        path: impl AsRef<Path>,
        auth: Option<AuthConfig>,
    ) -> Result<Self, String> {
        let spec_path = path.as_ref().to_string_lossy().to_string();
        let spec = OpenApiSpec::from_file(path)?;
        Self::from_spec(name, spec_path, spec, auth.as_ref())
    }

    /// Loads an OpenAPI document via `OpenApiSpec::from_url` and builds one
    /// tool per operation it declares. `url` is recorded as `spec_path()`.
    ///
    /// # Errors
    ///
    /// Returns an error if `OpenApiSpec::from_url` fails, or if the HTTP
    /// client cannot be built (e.g. an invalid API-key header).
    pub async fn from_url(
        name: String,
        url: &str,
        auth: Option<AuthConfig>,
    ) -> Result<Self, String> {
        let spec = OpenApiSpec::from_url(url).await?;
        Self::from_spec(name, url.to_string(), spec, auth.as_ref())
    }

    /// Builds one `OpenApiOperationTool` per operation in `spec`.
    ///
    /// All tools share the same spec and a single HTTP client. Path items
    /// written as `$ref` are skipped rather than resolved. Within a path,
    /// operations are emitted in the order GET, POST, PUT, DELETE, PATCH,
    /// HEAD, OPTIONS, TRACE.
    fn from_spec(
        name: String,
        spec_path: String,
        spec: OpenApiSpec,
        auth: Option<&AuthConfig>,
    ) -> Result<Self, String> {
        let spec = Arc::new(spec);
        let http_client = Arc::new(Self::create_http_client(auth)?);
        let mut tools: Vec<Box<dyn BaseTool>> = Vec::new();

        for (path, path_item_ref) in &spec.spec().paths.paths {
            let path_item = match path_item_ref {
                openapiv3::ReferenceOr::Item(item) => item,
                // Referenced path items are not resolved; they yield no tools.
                openapiv3::ReferenceOr::Reference { .. } => continue,
            };

            let operations = [
                ("GET", &path_item.get),
                ("POST", &path_item.post),
                ("PUT", &path_item.put),
                ("DELETE", &path_item.delete),
                ("PATCH", &path_item.patch),
                ("HEAD", &path_item.head),
                ("OPTIONS", &path_item.options),
                ("TRACE", &path_item.trace),
            ];

            for (method, operation) in operations {
                if let Some(operation) = operation {
                    let tool = OpenApiOperationTool::new(
                        Self::operation_name(method, path, operation),
                        Self::operation_description(method, path, operation),
                        method.to_string(),
                        path.clone(),
                        Arc::clone(&spec),
                        Arc::clone(&http_client),
                        auth.cloned(),
                    );
                    tools.push(Box::new(tool));
                }
            }
        }

        Ok(Self {
            name,
            spec_path,
            tools,
        })
    }

    /// Tool name for an operation.
    ///
    /// Uses its `operationId` when present. Otherwise the name is derived from
    /// the method and path, e.g. `GET /users/{id}` becomes `get_users_by_id`.
    fn operation_name(method: &str, path: &str, operation: &openapiv3::Operation) -> String {
        operation.operation_id.clone().unwrap_or_else(|| {
            let path_normalized = path
                .trim_start_matches('/')
                .replace('/', "_")
                .replace('{', "by_")
                .replace('}', "");
            format!("{}_{}", method.to_lowercase(), path_normalized)
        })
    }

    /// Tool description for an operation.
    ///
    /// Falls back from `summary` to `description` to `"<METHOD> <path>"`.
    fn operation_description(
        method: &str,
        path: &str,
        operation: &openapiv3::Operation,
    ) -> String {
        operation
            .summary
            .clone()
            .or_else(|| operation.description.clone())
            .unwrap_or_else(|| format!("{method} {path}"))
    }

    /// Builds the HTTP client shared by every tool in the set.
    ///
    /// Outside WASI preview 1, requests get a 30-second timeout. When `auth`
    /// is a header-located API key, the key is installed as a default header
    /// on every request. Other auth variants are not applied here; they are
    /// passed on to each tool instead.
    ///
    /// # Errors
    ///
    /// Returns an error if the API-key header name or value is not valid for
    /// HTTP, or if `reqwest` fails to build the client.
    fn create_http_client(auth: Option<&AuthConfig>) -> Result<reqwest::Client, String> {
        let mut builder = reqwest::Client::builder();
        #[cfg(not(all(target_os = "wasi", target_env = "p1")))]
        {
            builder = builder.timeout(std::time::Duration::from_secs(30));
        }
        if let Some(AuthConfig::ApiKey {
            location: HeaderOrQuery::Header,
            name,
            value,
        }) = auth
        {
            let header_name = reqwest::header::HeaderName::from_bytes(name.as_bytes())
                .map_err(|e| format!("Invalid header name: {e}"))?;
            let mut header_value = reqwest::header::HeaderValue::from_str(value)
                .map_err(|e| format!("Invalid header value: {e}"))?;
            // The client's Debug output includes its default headers. Marking
            // the value sensitive makes it print as "Sensitive" instead of
            // the raw API key.
            header_value.set_sensitive(true);
            let mut headers = reqwest::header::HeaderMap::new();
            headers.insert(header_name, header_value);
            builder = builder.default_headers(headers);
        }
        builder.build().map_err(|e| e.to_string())
    }

    /// The identifier given to this toolset at construction.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// The file path or URL the spec was loaded from.
    #[must_use]
    pub fn spec_path(&self) -> &str {
        &self.spec_path
    }
}
/// Credentials used when calling the operations of an OpenAPI toolset.
///
/// `Debug` is implemented by hand so that secrets (the Basic `password` and
/// the API-key `value`) are printed as `<redacted>` and never reach logs.
#[derive(Clone)]
pub enum AuthConfig {
    /// Username/password credentials. NOTE(review): presumably applied as
    /// HTTP Basic auth by each operation tool; not applied by the client
    /// builder.
    Basic {
        username: String,
        password: String,
    },
    /// An API key sent under `name`. With `HeaderOrQuery::Header` it is
    /// installed as a default client header; the query-string case is
    /// presumably handled per request by the operation tool.
    ApiKey {
        location: HeaderOrQuery,
        name: String,
        value: String,
    },
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &format_args!("<redacted>"))
                .finish(),
            Self::ApiKey {
                location, name, ..
            } => f
                .debug_struct("ApiKey")
                .field("location", location)
                .field("name", name)
                .field("value", &format_args!("<redacted>"))
                .finish(),
        }
    }
}

/// Where an API key is placed on outgoing requests.
#[derive(Debug, Clone, Copy)]
pub enum HeaderOrQuery {
    /// Sent as an HTTP header.
    Header,
    /// Sent as a query-string parameter.
    Query,
}
// WASI preview 1 is single-threaded, so the futures there need not be `Send`.
#[cfg_attr(all(target_os = "wasi", target_env = "p1"), async_trait::async_trait(?Send))]
#[cfg_attr(
    not(all(target_os = "wasi", target_env = "p1")),
    async_trait::async_trait
)]
impl BaseToolset for OpenApiToolSet {
    /// Borrows every generated operation tool, in construction order.
    async fn get_tools(&self) -> Vec<&dyn BaseTool> {
        let mut borrowed = Vec::with_capacity(self.tools.len());
        for boxed in &self.tools {
            borrowed.push(&**boxed);
        }
        borrowed
    }

    /// Nothing to release: the toolset holds no connections of its own
    /// beyond what is dropped with it.
    async fn close(&self) {}
}
/// Summarizes the toolset: the boxed tools are not `Debug`, so only their
/// count is shown.
impl std::fmt::Debug for OpenApiToolSet {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tools_count = self.tools.len();
        let mut repr = formatter.debug_struct("OpenApiToolSet");
        repr.field("name", &self.name);
        repr.field("spec_path", &self.spec_path);
        repr.field("tools_count", &tools_count);
        repr.finish()
    }
}