use async_trait::async_trait;
use serde::{Serialize, de::DeserializeOwned};
use std::fmt::Debug;
/// Marker trait for the messages a secretary accepts.
///
/// It has no methods of its own: any type that is thread-safe, cloneable,
/// debuggable and (de)serializable with serde qualifies, and a blanket impl
/// in this module implements it automatically for every such type.
pub trait SecretaryInput:
    'static + Send + Sync + Clone + Debug + Serialize + DeserializeOwned
{
}
/// Marker trait for the messages a secretary emits.
///
/// Like the input marker it has no methods: any thread-safe, cloneable,
/// debuggable and serde-(de)serializable type qualifies, and a blanket impl
/// in this module implements it automatically for every such type.
pub trait SecretaryOutput:
    'static + Send + Sync + Clone + Debug + Serialize + DeserializeOwned
{
}
/// The core behaviour of a secretary.
///
/// A secretary consumes `Self::Input` messages, produces `Self::Output`
/// messages and keeps a `Self::State` inside the
/// `super::context::SecretaryContext` handed to each async hook.
/// Only `initial_state` and `handle_input` must be implemented; every other
/// hook has a default that does nothing.
#[async_trait]
pub trait SecretaryBehavior: Send + Sync {
    /// Message type this secretary accepts.
    type Input: SecretaryInput;
    /// Message type this secretary emits.
    type Output: SecretaryOutput;
    /// State type carried by the context passed to every async hook.
    type State: Send + Sync + 'static;

    /// Builds a starting value of `Self::State`.
    fn initial_state(&self) -> Self::State;

    /// Handles a single input and returns the outputs it produced (possibly none).
    ///
    /// `ctx` is passed mutably, so implementations may update state while
    /// handling the input.
    ///
    /// # Errors
    /// Any failure is reported as an [`anyhow::Error`].
    async fn handle_input(
        &self,
        input: Self::Input,
        ctx: &mut super::context::SecretaryContext<Self::State>,
    ) -> anyhow::Result<Vec<Self::Output>>;

    /// Optional greeting output; the default returns `None`.
    ///
    /// NOTE(review): presumably emitted when a session starts — confirm with
    /// the runtime that calls it.
    fn welcome_message(&self) -> Option<Self::Output> {
        None
    }

    /// Hook for work that is not triggered by an input.
    ///
    /// The default produces no outputs. NOTE(review): presumably invoked on a
    /// timer by the runtime — confirm.
    async fn periodic_check(
        &self,
        _ctx: &mut super::context::SecretaryContext<Self::State>,
    ) -> anyhow::Result<Vec<Self::Output>> {
        Ok(Vec::new())
    }

    /// Cleanup hook for a disconnect; the default does nothing and succeeds.
    async fn on_disconnect(
        &self,
        _ctx: &mut super::context::SecretaryContext<Self::State>,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Maps an error to an optional output; the default returns `None`.
    fn handle_error(&self, _error: &anyhow::Error) -> Option<Self::Output> {
        None
    }
}
/// A single named step of a multi-phase workflow.
///
/// A phase consumes an `Input`, may read or update the shared context, and
/// reports how to proceed through a [`PhaseResult`].
#[async_trait]
pub trait PhaseHandler<
    Input: Send + 'static,
    Output: Send + 'static,
    State: Send + Sync + 'static,
>: Send + Sync
{
    /// Name identifying this phase.
    fn name(&self) -> &str;

    /// Runs the phase on `input`.
    ///
    /// Non-error outcomes (continue, need more input, skip, abort) are
    /// expressed through [`PhaseResult`].
    ///
    /// # Errors
    /// Hard failures are reported as an [`anyhow::Error`].
    async fn handle(
        &self,
        input: Input,
        ctx: &mut super::context::SecretaryContext<State>,
    ) -> anyhow::Result<PhaseResult<Output>>;

    /// Returns `true` when this phase may be bypassed for `input`.
    ///
    /// The default never skips. NOTE(review): presumably consulted before
    /// `handle` — confirm against the orchestrator.
    fn can_skip(&self, _input: &Input, _ctx: &super::context::SecretaryContext<State>) -> bool {
        false
    }
}
/// Outcome of running a single workflow phase.
///
/// `PartialEq`/`Eq` are derived (bounded on `T`) so results can be compared
/// directly, e.g. with `assert_eq!`, instead of only through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PhaseResult<T> {
    /// The phase finished and produced this value.
    Continue(T),
    /// The phase needs more input before it can finish.
    NeedInput {
        /// Any progress made so far, if the phase has some.
        partial_result: Option<T>,
        /// Text describing the input that is needed.
        prompt: String,
    },
    /// The phase was skipped.
    Skip,
    /// The phase requests that the workflow stop.
    Abort {
        /// Human-readable explanation for aborting.
        reason: String,
    },
}
/// Drives a complete workflow for one input and reports its overall outcome.
#[async_trait]
pub trait WorkflowOrchestrator<
    Input: Send + 'static,
    Output: Send + 'static,
    State: Send + Sync + 'static,
>: Send + Sync
{
    /// Name identifying this workflow.
    fn name(&self) -> &str;

    /// Runs the workflow on `input`, with mutable access to the context.
    ///
    /// # Errors
    /// Hard failures are reported as an [`anyhow::Error`]; other outcomes are
    /// expressed through [`WorkflowResult`].
    async fn execute(
        &self,
        input: Input,
        ctx: &mut super::context::SecretaryContext<State>,
    ) -> anyhow::Result<WorkflowResult<Output>>;
}
/// Overall outcome of running a workflow.
///
/// `PartialEq`/`Eq` are derived (bounded on `T`) so outcomes can be compared
/// directly, e.g. with `assert_eq!`, instead of only through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WorkflowResult<T> {
    /// The workflow finished and produced this value.
    Completed(T),
    /// The workflow needs more input; the string describes what is needed.
    NeedInput(String),
    /// The workflow was skipped.
    Skipped,
    /// The workflow stopped early; the string explains why.
    Aborted(String),
}
/// A handler responsible for a subset of inputs.
///
/// NOTE(review): presumably a dispatcher calls `can_handle` first and only
/// routes matching inputs to `handle` — confirm against the dispatcher.
#[async_trait]
pub trait InputHandler<
    Input: Send + 'static,
    Output: Send + 'static,
    State: Send + Sync + 'static,
>: Send + Sync
{
    /// Name identifying this handler.
    fn name(&self) -> &str;

    /// Returns `true` if this handler accepts `input`.
    fn can_handle(&self, input: &Input) -> bool;

    /// Processes `input` and returns the outputs it produced (possibly none).
    ///
    /// # Errors
    /// Any failure is reported as an [`anyhow::Error`].
    async fn handle(
        &self,
        input: Input,
        ctx: &mut super::context::SecretaryContext<State>,
    ) -> anyhow::Result<Vec<Output>>;
}
/// Lifecycle and activity notifications passed to event listeners.
///
/// The `State` parameter ties an event to the state type of the secretary
/// that emitted it. No public variant stores a `State` value.
pub enum SecretaryEvent<State> {
    Started,
    Stopped,
    InputReceived,
    OutputSent,
    StateChanged,
    /// Application-defined event identified by a free-form string.
    Custom(String),
    /// Not a real event: exists only so the otherwise-unused `State`
    /// parameter is accepted by the compiler.
    #[doc(hidden)]
    _Phantom(std::marker::PhantomData<State>),
}

// Hand-written instead of `#[derive(Debug)]`: the derive would add a
// `State: Debug` bound even though no variant stores a `State`, so events for
// non-`Debug` states could not be printed. The output is identical to the
// derived format.
impl<State> Debug for SecretaryEvent<State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Started => f.write_str("Started"),
            Self::Stopped => f.write_str("Stopped"),
            Self::InputReceived => f.write_str("InputReceived"),
            Self::OutputSent => f.write_str("OutputSent"),
            Self::StateChanged => f.write_str("StateChanged"),
            Self::Custom(name) => f.debug_tuple("Custom").field(name).finish(),
            // `PhantomData<T>` is `Debug` for every `T`, so no bound is needed.
            Self::_Phantom(marker) => f.debug_tuple("_Phantom").field(marker).finish(),
        }
    }
}
/// Observer notified of `SecretaryEvent`s.
///
/// Listeners get only shared (`&`) access to the context, so they cannot
/// mutate it through this hook.
#[async_trait]
pub trait EventListener<State: Send + Sync + 'static>: Send + Sync {
    /// Name identifying this listener.
    fn name(&self) -> &str;

    /// Called with an event and a read-only view of the context.
    async fn on_event(
        &self,
        event: &SecretaryEvent<State>,
        ctx: &super::context::SecretaryContext<State>,
    );
}
/// Hook that wraps input handling with optional pre- and post-processing.
///
/// Both hooks have pass-through defaults, so implementors override only what
/// they need.
///
/// `Input` is bounded by `Sync` because `#[async_trait]` boxes each method's
/// future as `Send`, and the default bodies below move their `&Input`
/// argument into that future. A `&Input` is only `Send` when `Input: Sync`.
/// Every `SecretaryInput` is already `Sync`, so the intended input types
/// still qualify.
///
/// NOTE(review): by the same reasoning the defaults also need
/// `super::context::SecretaryContext<State>: Sync`. That type is defined
/// elsewhere — confirm it holds.
#[async_trait]
pub trait Middleware<Input, Output, State>: Send + Sync
where
    Input: Send + Sync + Clone + 'static,
    Output: Send + 'static,
    State: Send + Sync + 'static,
{
    /// Name identifying this middleware.
    fn name(&self) -> &str;

    /// Runs before the input is handled; the default returns `None`.
    ///
    /// NOTE(review): returning `Some(outputs)` presumably short-circuits
    /// normal handling with those outputs — confirm against the dispatcher.
    async fn before_handle(
        &self,
        _input: &Input,
        _ctx: &super::context::SecretaryContext<State>,
    ) -> Option<Vec<Output>> {
        None
    }

    /// Receives the outputs produced for `input` and returns the outputs to
    /// keep. The default returns `outputs` unchanged.
    async fn after_handle(
        &self,
        _input: &Input,
        outputs: Vec<Output>,
        _ctx: &super::context::SecretaryContext<State>,
    ) -> Vec<Output> {
        outputs
    }
}
// Blanket impl: every type meeting the marker's supertrait bounds is an input,
// so callers never implement `SecretaryInput` by hand.
impl<T: 'static + Send + Sync + Clone + Debug + Serialize + DeserializeOwned> SecretaryInput
    for T
{
}
// Blanket impl: every type meeting the marker's supertrait bounds is an output,
// so callers never implement `SecretaryOutput` by hand.
impl<T: 'static + Send + Sync + Clone + Debug + Serialize + DeserializeOwned> SecretaryOutput
    for T
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal serde-capable input used to exercise the blanket impl.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    enum TestInput {
        Text(String),
    }

    /// Minimal serde-capable output used to exercise the blanket impl.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    enum TestOutput {
        Reply(String),
    }

    // These compile only when the type argument satisfies the bound, so the
    // test below is effectively a compile-time check.
    fn requires_input<T: SecretaryInput>() {}
    fn requires_output<T: SecretaryOutput>() {}

    #[test]
    fn test_auto_impl() {
        requires_input::<TestInput>();
        requires_output::<TestOutput>();
    }
}