use std::future::Future;
use std::time::Duration;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::error::{
ContextError, DurableError, EmbeddingError, HookError, ProviderError, ToolError,
};
use crate::stream::StreamHandle;
use crate::types::{
CompletionRequest, CompletionResponse, ContentItem, EmbeddingRequest, EmbeddingResponse,
Message, ToolContext, ToolDefinition, ToolOutput,
};
use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
/// A completion backend that answers a [`CompletionRequest`] either with one
/// [`CompletionResponse`] or with a [`StreamHandle`].
///
/// Implementors must be [`WasmCompatSend`] + [`WasmCompatSync`], and every future they
/// return must be [`WasmCompatSend`].
pub trait Provider
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Produces a single [`CompletionResponse`] for `request`, or a [`ProviderError`].
    fn complete(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;

    /// Produces a [`StreamHandle`] for `request`, or a [`ProviderError`].
    fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
}
/// A backend that turns an [`EmbeddingRequest`] into an [`EmbeddingResponse`].
///
/// Implementors must be [`WasmCompatSend`] + [`WasmCompatSync`], and the returned
/// future must be [`WasmCompatSend`].
pub trait EmbeddingProvider
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Computes embeddings for `request`, failing with an [`EmbeddingError`].
    fn embed(
        &self,
        request: EmbeddingRequest,
    ) -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend;
}
/// A statically-typed tool with its own argument, output and error types.
///
/// Every `Tool` is also usable as a [`ToolDyn`] through a blanket implementation.
/// That implementation decodes JSON into [`Tool::Args`] before calling [`Tool::call`]
/// and serializes the [`Tool::Output`] it returns.
pub trait Tool
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Identifier of this tool; it is also what [`ToolDyn::name`] returns.
    const NAME: &'static str;

    /// Argument type, deserializable from JSON and describable via `JsonSchema`.
    type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;

    /// Successful result type; must be serializable.
    type Output: Serialize;

    /// Error type returned by [`Tool::call`].
    type Error: std::error::Error + WasmCompatSend + 'static;

    /// Returns the [`ToolDefinition`] describing this tool.
    fn definition(&self) -> ToolDefinition;

    /// Runs the tool with already-decoded `args` and the invocation context `ctx`.
    fn call(
        &self,
        args: Self::Args,
        ctx: &ToolContext,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
/// Object-safe, JSON-in / JSON-out view of a tool.
///
/// Returns a boxed future ([`WasmBoxedFuture`]) instead of `impl Future`, so the trait
/// can be used behind a pointer as `dyn ToolDyn`.
pub trait ToolDyn
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// The tool's name.
    fn name(&self) -> &str;

    /// The tool's [`ToolDefinition`].
    fn definition(&self) -> ToolDefinition;

    /// Invokes the tool with raw JSON `input`; the future borrows `self` and `ctx` for `'a`.
    fn call_dyn<'a>(
        &'a self,
        input: serde_json::Value,
        ctx: &'a ToolContext,
    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
}
/// Blanket adapter that exposes any typed [`Tool`] through the object-safe [`ToolDyn`]
/// interface.
impl<T: Tool> ToolDyn for T {
    fn name(&self) -> &str {
        <T as Tool>::NAME
    }

    fn definition(&self) -> ToolDefinition {
        <T as Tool>::definition(self)
    }

    /// Decodes `input` into `T::Args`, runs [`Tool::call`], and wraps the serialized result
    /// in a non-error [`ToolOutput`].
    ///
    /// Errors are mapped as follows:
    /// - `input` does not deserialize into `T::Args` → [`ToolError::InvalidInput`]
    /// - the tool returns an error, or its output fails to serialize →
    ///   [`ToolError::ExecutionFailed`]
    fn call_dyn<'a>(
        &'a self,
        input: serde_json::Value,
        ctx: &'a ToolContext,
    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
        Box::pin(async move {
            let args = match serde_json::from_value::<T::Args>(input) {
                Ok(parsed) => parsed,
                Err(err) => return Err(ToolError::InvalidInput(err.to_string())),
            };

            // The tool's error is flattened to its Display text. NOTE(review): presumably
            // because `T::Error` is only bounded by `WasmCompatSend + 'static` and so may not
            // fit the variant's boxed error type directly — confirm.
            let produced = match <T as Tool>::call(self, args, ctx).await {
                Ok(value) => value,
                Err(err) => return Err(ToolError::ExecutionFailed(err.to_string().into())),
            };

            let structured = match serde_json::to_value(&produced) {
                Ok(json) => json,
                Err(err) => return Err(ToolError::ExecutionFailed(Box::new(err))),
            };

            // A JSON string is surfaced verbatim (no quotes); any other value is rendered
            // with its Display (JSON text) form.
            let text = structured
                .as_str()
                .map_or_else(|| structured.to_string(), str::to_owned);

            Ok(ToolOutput {
                content: vec![ContentItem::Text(text)],
                structured_content: Some(structured),
                is_error: false,
            })
        })
    }
}
/// Policy that decides when a message history should be compacted and performs the
/// compaction.
pub trait ContextStrategy
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Returns `true` when `messages`, given the supplied `token_count`, should be compacted.
    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;

    /// Consumes `messages` and yields the compacted history, or a [`ContextError`].
    fn compact(
        &self,
        messages: Vec<Message>,
    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;

    /// Returns this strategy's token estimate for `messages`.
    fn token_estimate(&self, messages: &[Message]) -> usize;
}
/// An event delivered to [`ObservabilityHook::on_event`].
///
/// Events borrow their payloads for `'a` instead of owning them. Every field is either a
/// shared reference or a `usize`, so the type is `Copy`. Because `on_event` takes the
/// event by value, this lets one event be handed to several hooks without re-borrowing
/// or cloning the underlying data.
#[derive(Debug, Clone, Copy)]
pub enum HookEvent<'a> {
    /// A loop iteration, identified by `turn`.
    /// NOTE(review): confirm whether `turn` is 0- or 1-based.
    LoopIteration {
        turn: usize,
    },
    /// Emitted with the [`CompletionRequest`] about to be sent to the model.
    PreLlmCall {
        request: &'a CompletionRequest,
    },
    /// Emitted with the [`CompletionResponse`] received from the model.
    PostLlmCall {
        response: &'a CompletionResponse,
    },
    /// Emitted with a tool's name and JSON `input` before it runs.
    PreToolExecution {
        tool_name: &'a str,
        input: &'a serde_json::Value,
    },
    /// Emitted with a tool's name and its [`ToolOutput`] after it runs.
    PostToolExecution {
        tool_name: &'a str,
        output: &'a ToolOutput,
    },
    /// Context compaction. Presumably carries the token counts before (`old_tokens`) and
    /// after (`new_tokens`) compaction — confirm.
    ContextCompaction {
        old_tokens: usize,
        new_tokens: usize,
    },
    /// Start of the session identified by `session_id`.
    SessionStart {
        session_id: &'a str,
    },
    /// End of the session identified by `session_id`.
    SessionEnd {
        session_id: &'a str,
    },
}
/// What a hook asks the caller to do after handling an event; returned by
/// [`ObservabilityHook::on_event`].
///
/// All payloads are `String`s, so the type derives `Clone`, `PartialEq` and `Eq`. Actions
/// can therefore be duplicated and compared directly instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Proceed as normal.
    Continue,
    /// Skip, with a `reason`. NOTE(review): exactly what is skipped (e.g. the pending LLM
    /// or tool call) is decided by the caller — confirm.
    Skip { reason: String },
    /// Terminate, with a `reason`. NOTE(review): confirm whether this ends the loop or the
    /// whole session.
    Terminate { reason: String },
}
/// Observer that receives each [`HookEvent`] and answers with a [`HookAction`]
/// (continue, skip or terminate), or a [`HookError`].
pub trait ObservabilityHook
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Handles `event`, which is taken by value and borrows its payload.
    fn on_event(
        &self,
        event: HookEvent<'_>,
    ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
}
/// Per-call settings passed to [`DurableContext::execute_llm_call`] and
/// [`DurableContext::execute_tool`].
#[derive(Clone, Debug)]
pub struct ActivityOptions {
    /// Start-to-close timeout for the operation.
    pub start_to_close_timeout: Duration,
    /// Optional heartbeat timeout.
    pub heartbeat_timeout: Option<Duration>,
    /// Optional retry configuration.
    pub retry_policy: Option<RetryPolicy>,
}
/// Retry configuration carried by [`ActivityOptions::retry_policy`].
///
/// Derives `PartialEq` so two policies can be compared directly, e.g. in tests or to
/// detect configuration changes. `Eq` is not derived because `backoff_coefficient` is
/// an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Growth factor for the retry delay. Presumably each interval is the previous one
    /// multiplied by this value — confirm with the backend.
    pub backoff_coefficient: f64,
    /// Maximum number of attempts. NOTE(review): whether `0` means "unlimited" depends
    /// on the backend — confirm.
    pub maximum_attempts: u32,
    /// Upper bound on the retry delay.
    pub maximum_interval: Duration,
    /// Error identifiers that should not be retried. NOTE(review): confirm whether these
    /// match error type names or messages.
    pub non_retryable_errors: Vec<String>,
}
/// Execution environment whose operations take [`ActivityOptions`] and fail with
/// [`DurableError`].
///
/// NOTE(review): the method names mirror workflow-engine concepts (activities, signals,
/// continue-as-new). Confirm the exact durability and replay guarantees against the
/// concrete implementations.
pub trait DurableContext
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Runs the completion `request` under `options`.
    fn execute_llm_call(
        &self,
        request: CompletionRequest,
        options: ActivityOptions,
    ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;

    /// Runs the tool named `tool_name` with JSON `input` and `ctx` under `options`.
    fn execute_tool(
        &self,
        tool_name: &str,
        input: serde_json::Value,
        ctx: &ToolContext,
        options: ActivityOptions,
    ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;

    /// Waits up to `timeout` for the signal named `signal_name` and decodes its payload as `T`.
    ///
    /// `Ok(None)` means no payload was produced. NOTE(review): presumably on timeout — confirm.
    fn wait_for_signal<T>(
        &self,
        signal_name: &str,
        timeout: Duration,
    ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend
    where
        T: DeserializeOwned + WasmCompatSend;

    /// Reports whether the caller should invoke [`DurableContext::continue_as_new`].
    fn should_continue_as_new(&self) -> bool;

    /// Continues as a new run seeded with `state`.
    fn continue_as_new(
        &self,
        state: serde_json::Value,
    ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;

    /// Completes after `duration`; cannot fail.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + WasmCompatSend;

    /// The current time according to this context, in UTC.
    fn now(&self) -> chrono::DateTime<chrono::Utc>;
}
/// Result of [`PermissionPolicy::check`] for a tool invocation.
///
/// Derives `PartialEq`/`Eq` (all payloads are `String`s), so decisions can be compared
/// with `==` or `assert_eq!` instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call. NOTE(review): the `String` is presumably a human-readable
    /// reason — confirm.
    Deny(String),
    /// Ask before proceeding. NOTE(review): the `String` is presumably the prompt or
    /// reason to show — confirm.
    Ask(String),
}
/// Synchronous gatekeeper that maps a tool name and its JSON input to a
/// [`PermissionDecision`].
pub trait PermissionPolicy
where
    Self: WasmCompatSend + WasmCompatSync,
{
    /// Decides whether the tool named `tool_name` may run with `input`.
    fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
}