pub mod proto;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tonic::transport::Channel;
use crate::error::Error;
/// A gRPC transport to the external API that routes every request through an
/// optional path prefix.
///
/// Built by [`connect`]. It implements `tower_service::Service` over
/// `http::Request<tonic::body::Body>` and delegates to an inner tonic [`Channel`].
#[derive(Clone, Debug)]
pub struct ExternalChannel {
    /// Underlying tonic channel. When a prefix is present it is connected to just the
    /// URL's scheme and authority; otherwise it is connected to the URL as given.
    inner: Channel,
    /// Path prefix taken from the API URL with trailing slashes removed; empty when
    /// the URL has no path component.
    prefix: String,
}
/// Opens a gRPC channel to the external API at `api_url`.
///
/// A path component in `api_url` (e.g. `https://host/api/v1`) is treated as a path
/// prefix. In that case the channel connects to `scheme://authority` only, and the
/// returned [`ExternalChannel`] prepends the prefix to every request path. Without a
/// path component, `api_url` is used as the endpoint unchanged.
///
/// TLS with the enabled root certificates is configured when the URL's scheme is
/// `https` (matched case-insensitively). The connect timeout is 10 s and the request
/// timeout is 30 s.
///
/// # Errors
///
/// - `Error::External` if `api_url` cannot be parsed, has a path prefix but no
///   authority, TLS configuration fails, or the connection cannot be established.
/// - `Error::InvalidPathPrefix` if the path prefix cannot form a valid request path.
pub async fn connect(api_url: &str) -> Result<ExternalChannel, Error> {
    let uri: http::Uri = api_url
        .parse()
        .map_err(|e| Error::External(format!("Invalid API URL '{api_url}': {e}")))?;

    // Trimming trailing slashes maps both "" and "/" (no real path) to an empty prefix.
    let prefix = uri.path().trim_end_matches('/').to_string();

    if !prefix.is_empty() {
        // Validate the prefix up-front against a representative gRPC method path, so a
        // bad prefix is reported here rather than as a failure on every request.
        let probe = format!("{prefix}/chat.external.SmokeTest/Method");
        probe
            .parse::<http::uri::PathAndQuery>()
            .map_err(|e| Error::InvalidPathPrefix {
                prefix: prefix.clone(),
                reason: e.to_string(),
            })?;
    }

    let base_url = if prefix.is_empty() {
        api_url.to_string()
    } else {
        let scheme = uri.scheme_str().unwrap_or("https");
        let authority = uri
            .authority()
            .map(http::uri::Authority::as_str)
            .ok_or_else(|| Error::External(format!("Missing authority in URL: {api_url}")))?;
        format!("{scheme}://{authority}")
    };

    let endpoint = tonic::transport::Endpoint::from_shared(base_url.clone())
        .map_err(|e| Error::External(format!("Invalid API URL '{base_url}': {e}")))?
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30));

    // Decide on TLS from the parsed scheme: `http::Uri` recognises `HTTPS://` as https,
    // whereas a raw `starts_with("https://")` check is case-sensitive and would
    // silently skip TLS for an upper- or mixed-case scheme.
    let use_tls = uri.scheme() == Some(&http::uri::Scheme::HTTPS);
    let endpoint = if use_tls {
        endpoint
            .tls_config(tonic::transport::ClientTlsConfig::new().with_enabled_roots())
            .map_err(|e| Error::External(format!("TLS configuration failed: {e}")))?
    } else {
        endpoint
    };

    let channel = endpoint
        .connect()
        .await
        .map_err(|e| Error::External(format!("Failed to connect to {base_url}: {e}")))?;

    Ok(ExternalChannel {
        inner: channel,
        prefix,
    })
}
pub fn auth_request<T>(api_key: &str, body: T) -> Result<tonic::Request<T>, Error> {
let value = format!("Bearer {api_key}")
.parse::<tonic::metadata::MetadataValue<_>>()
.map_err(|_| Error::InvalidApiKey)?;
let mut req = tonic::Request::new(body);
req.metadata_mut().insert("authorization", value);
Ok(req)
}
/// Forwards HTTP/gRPC requests to the inner [`Channel`], first prepending the
/// configured path prefix (if any) to each request's path.
impl tower_service::Service<http::Request<tonic::body::Body>> for ExternalChannel {
    type Response = http::Response<tonic::body::Body>;
    type Error = tonic::transport::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        tower_service::Service::poll_ready(&mut self.inner, cx)
    }

    fn call(&mut self, mut req: http::Request<tonic::body::Body>) -> Self::Future {
        // Sentinel path used when the prefix cannot be applied to a request.
        const INVALID_PREFIX_PATH: &str = "/__pcs_external_invalid_path__";

        if !self.prefix.is_empty() && !prepend_path_prefix(&mut req, &self.prefix) {
            // Fail closed: never send the request to its un-prefixed path. Instead,
            // redirect it to a sentinel path that (presumably) no server implements,
            // so the call errors out. `from_static` is infallible for this constant
            // literal, unlike the previous `parse().unwrap_or_else(..)`, whose fallback
            // would have kept the un-prefixed URI.
            *req.uri_mut() = http::Uri::from_static(INVALID_PREFIX_PATH);
        }

        Box::pin(tower_service::Service::call(&mut self.inner, req))
    }
}
#[cfg(test)]
// Tests deliberately unwrap: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Asserts that `key` is refused by [`auth_request`] with `Error::InvalidApiKey`.
    fn assert_key_rejected(key: &str) {
        let outcome = auth_request(key, ());
        assert!(matches!(outcome, Err(Error::InvalidApiKey)));
    }

    #[test]
    fn auth_request_accepts_normal_key() {
        let request = auth_request("pk_live_abc123", ()).unwrap();
        let header = request.metadata().get("authorization").unwrap();
        assert_eq!(header, "Bearer pk_live_abc123");
    }

    #[test]
    fn auth_request_rejects_newline_in_key() {
        // CR/LF inside a header value could inject an additional header line.
        assert_key_rejected("pk_live_abc\r\nX-Injected: bad");
    }

    #[test]
    fn auth_request_rejects_nul_in_key() {
        assert_key_rejected("pk_live_abc\0nul");
    }
}
/// Rewrites `req`'s URI so that its path-and-query becomes `prefix` followed by the
/// original path-and-query (`/` when the URI has none).
///
/// Returns `true` on success. Returns `false` and leaves `req` unmodified when the
/// combined path-and-query or the rebuilt URI is invalid.
fn prepend_path_prefix(req: &mut http::Request<tonic::body::Body>, prefix: &str) -> bool {
    let original = req.uri().path_and_query().map_or("/", |pq| pq.as_str());
    let combined: http::uri::PathAndQuery = match format!("{prefix}{original}").parse() {
        Ok(pq) => pq,
        Err(_) => return false,
    };

    // Keep scheme and authority as-is; only the path-and-query is swapped.
    let mut parts = req.uri().clone().into_parts();
    parts.path_and_query = Some(combined);

    match http::Uri::from_parts(parts) {
        Ok(rebuilt) => {
            *req.uri_mut() = rebuilt;
            true
        }
        Err(_) => false,
    }
}