pub mod oagw_responses;
pub mod provider_resolver;
pub mod providers;
pub mod request;
use std::pin::Pin;
use std::sync::LazyLock;
use std::task::{Context, Poll};
use futures::Stream;
use futures::StreamExt;
use modkit_security::SecurityContext;
use oagw_sdk::error::{ServiceGatewayError, StreamingError};
use regex::Regex;
use serde::Serialize;
use tokio_util::sync::CancellationToken;
pub use request::{
FeatureFlag, LlmMessage, LlmRequest, LlmRequestBuilder, LlmTool, RequestMetadata, RequestType,
Role, UserIdentity,
};
pub use providers::{ProviderKind, create_provider};
/// Type-state marker for streaming requests: `LlmRequest<Streaming>` is what
/// [`LlmProvider::stream`] accepts.
pub struct Streaming;
/// Type-state marker for non-streaming requests: `LlmRequest<NonStreaming>` is
/// what [`LlmProvider::complete`] accepts.
pub struct NonStreaming;
/// Errors surfaced by [`LlmProvider`] implementations and [`ProviderStream`].
#[derive(Debug, thiserror::Error)]
pub enum LlmProviderError {
    /// The request was rejected by rate limiting.
    ///
    /// `retry_after_secs` is the retry hint in seconds, when one was supplied
    /// (the gateway conversion copies it from `RateLimitExceeded`).
    #[error("rate limited")]
    RateLimited { retry_after_secs: Option<u64> },
    /// The provider did not answer in time (the gateway conversion maps both
    /// connection and request timeouts here).
    #[error("provider timeout")]
    Timeout,
    /// Any other provider or gateway failure.
    ///
    /// `message` is the client-presentable text (the gateway conversion fills
    /// it via `sanitize_provider_message`); `raw_detail` keeps the original
    /// text and is exposed as this error's `source()` and via
    /// [`LlmProviderError::raw_detail`].
    ///
    /// NOTE(review): the derived `Debug` of this enum prints `raw_detail`
    /// verbatim, so `{:?}` formatting can emit unsanitized upstream text.
    #[error("provider error: {code}: {message}")]
    ProviderError {
        code: String,
        message: String,
        #[source]
        raw_detail: Option<RawDetail>,
    },
    /// The upstream is disabled (mapped from the gateway's `UpstreamDisabled`).
    #[error("provider unavailable")]
    ProviderUnavailable,
    /// The provider's output could not be interpreted, e.g. a stream that
    /// ended without a terminal event.
    #[error("invalid response: {detail}")]
    InvalidResponse { detail: String },
    /// A streaming transport failure; convertible from [`StreamingError`].
    #[error("stream error: {0}")]
    StreamError(#[from] StreamingError),
}
/// Unmodified upstream error text kept alongside a sanitized message.
///
/// `Display` writes the text as-is (formatter width/fill flags are ignored),
/// and `Debug` renders it as `RawDetail("…")`. It implements
/// [`std::error::Error`] so it can serve as an error `source()`.
#[derive(Debug)]
pub struct RawDetail(pub(crate) String);

impl std::fmt::Display for RawDetail {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for RawDetail {}
impl LlmProviderError {
#[must_use]
pub fn raw_detail(&self) -> Option<&str> {
match self {
LlmProviderError::ProviderError {
raw_detail: Some(rd),
..
} => Some(&rd.0),
_ => None,
}
}
}
// The three patterns below are string constants. `Regex::new` can only fail on a
// malformed pattern, which is a programming error, so `unwrap` is acceptable.

/// Provider object identifiers with the prefixes `resp_`, `chatcmpl-`, `cmpl-`
/// or `msg_`.
///
/// The leading `\b` keeps the prefixes from matching inside ordinary words
/// (otherwise `max_resp_tokens` would become `max_[provider_id]`).
#[allow(clippy::unwrap_used)] // constant pattern, see above
static RE_RESP_ID: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\b(resp_|chatcmpl-|cmpl-|msg_)[A-Za-z0-9]+").unwrap());

/// `http`/`https` URLs. A match ends at whitespace, `,`, `]`, `)`, `}` or a quote.
#[allow(clippy::unwrap_used)] // constant pattern, see above
static RE_URL: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r#"https?://[^\s,\])}"']+"#).unwrap());

/// API keys and bearer tokens.
///
/// - `sk-` keys accept `-` and `_` after the prefix, so segmented keys such as
///   `sk-proj-…` or `sk-ant-…` are covered. `\b` avoids hits inside words
///   like `disk-…`.
/// - The `Bearer` scheme is matched case-insensitively. The token accepts the
///   RFC 6750 `b64token` alphabet, including `~`, `+`, `/` and trailing `=`.
#[allow(clippy::unwrap_used)] // constant pattern, see above
static RE_CRED: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(\bsk-[A-Za-z0-9_\-]{10,}|(?i:bearer)\s+[A-Za-z0-9._~+/\-]+=*)").unwrap()
});

/// Redacts URLs, credentials and provider object IDs from a provider or gateway
/// message before it is shown to clients.
///
/// Matches are replaced with `[url]`, `[credential]` and `[provider_id]`
/// respectively.
pub(crate) fn sanitize_provider_message(msg: &str) -> String {
    // Order matters.
    // 1. URLs first. They often embed IDs (`…/responses/resp_x`). Redacting the
    //    ID first would splice a `]` into the URL, cut the URL match short and
    //    leave a stray bracket (`[url]]`).
    // 2. Credentials before IDs. An ID prefix inside a key must not split the
    //    key into fragments shorter than the credential pattern's minimum.
    let sanitized = RE_URL.replace_all(msg, "[url]");
    let sanitized = RE_CRED.replace_all(&sanitized, "[credential]");
    RE_RESP_ID.replace_all(&sanitized, "[provider_id]").into_owned()
}
impl From<ServiceGatewayError> for LlmProviderError {
fn from(err: ServiceGatewayError) -> Self {
match err {
ServiceGatewayError::RateLimitExceeded {
retry_after_secs, ..
} => LlmProviderError::RateLimited { retry_after_secs },
ServiceGatewayError::ConnectionTimeout { .. }
| ServiceGatewayError::RequestTimeout { .. } => LlmProviderError::Timeout,
ServiceGatewayError::UpstreamDisabled { .. } => LlmProviderError::ProviderUnavailable,
other => {
let raw = other.to_string();
let sanitized = sanitize_provider_message(&raw);
LlmProviderError::ProviderError {
code: "gateway_error".to_owned(),
message: sanitized,
raw_detail: Some(RawDetail(raw)),
}
}
}
}
}
pub use crate::domain::llm::{Citation, CitationSource, TextSpan, Usage};
/// Result of a non-streaming [`LlmProvider::complete`] call.
#[derive(Debug)]
pub struct ResponseResult {
    /// Response text content.
    pub content: String,
    /// Token usage for the request.
    pub usage: Usage,
    /// Provider-assigned response identifier.
    pub response_id: String,
    /// Citations attached to the response (may be empty).
    pub citations: Vec<Citation>,
    /// Raw provider payload as JSON.
    /// NOTE(review): presumably the full provider response body — confirm.
    pub raw_response: serde_json::Value,
}
/// Final result of a streamed request, as returned by
/// [`ProviderStream::into_outcome`].
#[derive(Debug)]
pub enum TerminalOutcome {
    /// The provider finished the response normally.
    Completed {
        usage: Usage,
        response_id: String,
        content: String,
        citations: Vec<Citation>,
        raw_response: serde_json::Value,
    },
    /// The request failed.
    ///
    /// `partial_content` holds the text received before the failure. `usage`
    /// is `None` when no usage is available, e.g. after a stream error.
    Failed {
        error: LlmProviderError,
        usage: Option<Usage>,
        partial_content: String,
    },
    /// The request stopped early for `reason` (e.g. `"cancelled"`). The text
    /// received so far is kept in `partial_content`.
    Incomplete {
        reason: String,
        usage: Usage,
        partial_content: String,
    },
}
/// Event fed into [`ProviderStream`] by a provider's inner stream.
#[derive(Debug)]
pub(crate) enum TranslatedEvent {
    /// Forwarded to the client as-is.
    Sse(ClientSseEvent),
    /// Ends the client-facing stream. The outcome is kept for `into_outcome`.
    Terminal(TerminalOutcome),
    /// Consumed without yielding anything to the client.
    Skip,
}
/// Event sent to the client over SSE.
///
/// Serialized with serde adjacent tagging:
/// `{"event": "<name>", "data": { ...variant fields... }}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "event", content = "data")]
pub enum ClientSseEvent {
    /// Incremental content. Deltas whose `type` is `"text"` are also appended
    /// to the accumulated text kept by [`ProviderStream`].
    #[serde(rename = "delta")]
    Delta {
        r#type: &'static str,
        content: String,
    },
    /// Notification about tool `name` in the given `phase`, with free-form
    /// JSON `details`.
    #[serde(rename = "tool")]
    Tool {
        phase: ToolPhase,
        name: &'static str,
        details: serde_json::Value,
    },
    /// Citations for the response.
    #[serde(rename = "citations")]
    Citations { items: Vec<Citation> },
}
pub use crate::domain::llm::ToolPhase;
/// Client-facing stream of [`ClientSseEvent`]s produced by an [`LlmProvider`].
///
/// Text deltas are accumulated as they pass through. The provider's
/// [`TerminalOutcome`] is captured rather than yielded; use
/// [`ProviderStream::into_outcome`] to obtain it.
pub struct ProviderStream {
    /// Provider events, already translated.
    // The boxed trait object is spelled out in place; an alias would be used only here.
    #[allow(clippy::type_complexity)]
    inner: Pin<Box<dyn Stream<Item = Result<TranslatedEvent, StreamingError>> + Send>>,
    /// Checked on every poll. Once cancelled, the stream yields `None`.
    cancel: CancellationToken,
    /// Stored when the inner stream emits `TranslatedEvent::Terminal`.
    terminal: Option<TerminalOutcome>,
    /// Concatenation of every `"text"` delta yielded so far.
    accumulated_text: String,
    /// Set once the client-facing stream has ended. Further polls return `None`.
    finished: bool,
}
impl std::fmt::Debug for ProviderStream {
    /// Summarises the stream's state without printing the inner stream, the
    /// captured terminal outcome or the accumulated text itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.cancel.is_cancelled();
        let accumulated_len = self.accumulated_text.len();
        let mut out = f.debug_struct("ProviderStream");
        out.field("cancelled", &cancelled);
        out.field("finished", &self.finished);
        out.field("accumulated_len", &accumulated_len);
        out.finish_non_exhaustive()
    }
}
impl ProviderStream {
    /// Wraps a provider's translated event stream.
    ///
    /// `cancel` is the token consulted by [`Self::cancel`], [`Self::is_cancelled`]
    /// and the `Stream` implementation. The inner stream also ends as soon as
    /// `cancel` fires, so a consumer waiting on a stalled upstream is woken by
    /// the cancellation. Without this it would only notice on the next upstream
    /// event.
    pub(crate) fn new(
        inner: impl Stream<Item = Result<TranslatedEvent, StreamingError>> + Send + 'static,
        cancel: CancellationToken,
    ) -> Self {
        // `poll_next` checks `is_cancelled()`, but that check registers no waker
        // with the token. `take_until` polls `cancelled()`, which does, so
        // cancelling wakes a parked task.
        let token = cancel.clone();
        let inner = inner.take_until(async move { token.cancelled().await });
        ProviderStream {
            inner: Box::pin(inner),
            cancel,
            terminal: None,
            accumulated_text: String::new(),
            finished: false,
        }
    }

    /// Cancels the token. Subsequent polls of this stream yield `None`.
    pub fn cancel(&self) {
        self.cancel.cancel();
    }

    /// Returns whether the token has been cancelled.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }

    /// Consumes the stream and returns its [`TerminalOutcome`].
    ///
    /// Client events are discarded; text deltas still accumulate and become
    /// `partial_content` where applicable.
    ///
    /// - A stream error yields `Failed` with [`LlmProviderError::StreamError`].
    /// - A captured terminal event is returned as-is. The remaining upstream
    ///   output is first drained, for at most 2 seconds, unless cancelled.
    /// - With no terminal event, a cancelled stream yields `Incomplete` with
    ///   reason `"cancelled"` and zero usage.
    /// - Any other stream without a terminal event yields `Failed` with
    ///   [`LlmProviderError::InvalidResponse`].
    pub async fn into_outcome(mut self) -> TerminalOutcome {
        // `poll_next` records text deltas and stashes the terminal outcome.
        while let Some(item) = self.next().await {
            if let Err(e) = item {
                return TerminalOutcome::Failed {
                    error: LlmProviderError::StreamError(e),
                    usage: None,
                    partial_content: self.accumulated_text,
                };
            }
        }
        // Read and discard whatever the upstream still sends after the terminal
        // event, bounded in time. NOTE(review): presumably so the upstream
        // response completes instead of being dropped mid-body — confirm.
        if self.terminal.is_some() && !self.cancel.is_cancelled() {
            let _drain = tokio::time::timeout(std::time::Duration::from_secs(2), async {
                while self.inner.next().await.is_some() {}
            })
            .await;
        }
        match self.terminal {
            Some(terminal) => terminal,
            None if self.cancel.is_cancelled() => TerminalOutcome::Incomplete {
                reason: "cancelled".to_owned(),
                usage: Usage {
                    input_tokens: 0,
                    output_tokens: 0,
                    cache_read_input_tokens: 0,
                    cache_write_input_tokens: 0,
                    reasoning_tokens: 0,
                },
                partial_content: self.accumulated_text,
            },
            None => TerminalOutcome::Failed {
                error: LlmProviderError::InvalidResponse {
                    detail: "stream ended without terminal event".to_owned(),
                },
                usage: None,
                partial_content: self.accumulated_text,
            },
        }
    }
}
/// Yields client SSE events while recording text and capturing the terminal
/// outcome.
///
/// - `Sse` events are forwarded. `"text"` deltas are also appended to
///   `accumulated_text`.
/// - A `Terminal` event is stored in `terminal` and ends the stream (`None`).
/// - `Skip` events are consumed and the inner stream is polled again.
/// - Errors are forwarded without ending the stream.
/// - Once `finished` is set, the inner stream ends, or the token is cancelled,
///   the stream yields `None`.
impl Stream for ProviderStream {
    type Item = Result<ClientSseEvent, StreamingError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ProviderStream` is `Unpin`, so plain mutable access is fine.
        let this = self.get_mut();
        if this.finished {
            return Poll::Ready(None);
        }
        // Cancellation is observed at poll time. `is_cancelled()` registers no
        // waker, so prompt wake-up after a cancel relies on the inner stream
        // being woken when the token fires.
        if this.cancel.is_cancelled() {
            this.finished = true;
            return Poll::Ready(None);
        }
        // Loop only so that `Skip` events can be consumed without yielding.
        loop {
            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(TranslatedEvent::Sse(event)))) => {
                    if let ClientSseEvent::Delta {
                        r#type: "text",
                        ref content,
                    } = event
                    {
                        this.accumulated_text.push_str(content);
                    }
                    return Poll::Ready(Some(Ok(event)));
                }
                Poll::Ready(Some(Ok(TranslatedEvent::Terminal(outcome)))) => {
                    // Keep the outcome for `into_outcome` instead of yielding it.
                    this.finished = true;
                    this.terminal = Some(outcome);
                    return Poll::Ready(None);
                }
                Poll::Ready(Some(Ok(TranslatedEvent::Skip))) => {}
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(e)));
                }
                Poll::Ready(None) => {
                    this.finished = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    // Re-check in case the token was cancelled while the inner
                    // stream was being polled.
                    if this.cancel.is_cancelled() {
                        this.finished = true;
                        return Poll::Ready(None);
                    }
                    return Poll::Pending;
                }
            }
        }
    }
}
/// A backend that serves LLM requests, either streamed or in one shot.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Starts a streaming request and returns the client-facing
    /// [`ProviderStream`].
    ///
    /// `ctx` is the caller's security context. `upstream_alias` selects the
    /// upstream to route through (NOTE(review): presumably an OAGW upstream
    /// alias — confirm). `cancel` is presumably the token the returned stream
    /// observes; confirm per implementation.
    ///
    /// # Errors
    /// Returns an [`LlmProviderError`] if the request cannot be started.
    async fn stream(
        &self,
        ctx: SecurityContext,
        request: LlmRequest<Streaming>,
        upstream_alias: &str,
        cancel: CancellationToken,
    ) -> Result<ProviderStream, LlmProviderError>;
    /// Performs a non-streaming request and returns the complete
    /// [`ResponseResult`].
    ///
    /// # Errors
    /// Returns an [`LlmProviderError`] if the request fails.
    async fn complete(
        &self,
        ctx: SecurityContext,
        request: LlmRequest<NonStreaming>,
        upstream_alias: &str,
    ) -> Result<ResponseResult, LlmProviderError>;
}
/// Returns a fresh [`LlmRequestBuilder`] for `model`.
///
/// Shorthand for [`LlmRequestBuilder::new`].
#[must_use]
pub fn llm_request(model: impl Into<String>) -> LlmRequestBuilder {
    LlmRequestBuilder::new(model)
}
#[cfg(test)]
#[path = "mod_tests.rs"]
mod mod_tests;