use crate::{
ElicitError, ElicitErrorKind, ElicitResult, Elicitation, ElicitationStyle, TypeMetadata,
};
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Abstraction over a channel that can prompt for input, invoke tools, and
/// expose per-type style and elicitation-stack state.
///
/// Implementors supply `send_prompt`, `call_tool`, `style_context`,
/// `with_style` and `elicitation_context`; every other method is provided in
/// terms of the two context accessors.
pub trait ElicitCommunicator: Clone + Send + Sync {
    /// Sends `prompt` and resolves to a `String` reply or an error.
    ///
    /// Required method; how the prompt is delivered is up to the implementor.
    fn send_prompt(
        &self,
        prompt: &str,
    ) -> impl std::future::Future<Output = ElicitResult<String>> + Send;

    /// Invokes a tool described by `params` and resolves to its result.
    ///
    /// Required method. Failures are reported as `rmcp::service::ServiceError`
    /// rather than through [`ElicitResult`].
    fn call_tool(
        &self,
        params: rmcp::model::CallToolRequestParams,
    ) -> impl std::future::Future<
        Output = Result<rmcp::model::CallToolResult, rmcp::service::ServiceError>,
    > + Send;

    /// Returns the [`StyleContext`] consulted by the style helpers below.
    fn style_context(&self) -> &StyleContext;

    /// Required method taking a style `S` to associate with type `T` and
    /// returning a communicator of the same type.
    ///
    /// NOTE(review): presumably the returned value carries `style` for `T`;
    /// the exact semantics are defined by implementors — confirm there.
    fn with_style<T: 'static, S: ElicitationStyle>(&self, style: S) -> Self;

    /// Returns the style stored for `T`, or `T::Style`'s default value when
    /// nothing is stored.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`StyleContext::get_style`].
    fn style_or_default<T: Elicitation + 'static>(&self) -> ElicitResult<T::Style>
    where
        T::Style: ElicitationStyle,
    {
        let registered = self.style_context().get_style::<T, T::Style>()?;
        Ok(registered.unwrap_or_default())
    }

    /// Returns the style stored for `T`; when nothing is stored, elicits a
    /// `T::Style` through this communicator instead.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`StyleContext::get_style`] and from the
    /// fallback `T::Style::elicit` call.
    fn style_or_elicit<T: Elicitation + 'static>(
        &self,
    ) -> impl std::future::Future<Output = ElicitResult<T::Style>> + Send
    where
        T::Style: ElicitationStyle,
    {
        async move {
            match self.style_context().get_style::<T, T::Style>()? {
                Some(style) => Ok(style),
                None => T::Style::elicit(self).await,
            }
        }
    }

    /// Returns the [`ElicitationContext`] backing the stack queries below.
    fn elicitation_context(&self) -> &ElicitationContext;

    /// Returns the entry on top of the elicitation stack, if any.
    ///
    /// Delegates to [`ElicitationContext::current`].
    fn current_type(&self) -> ElicitResult<Option<TypeMetadata>> {
        let context = self.elicitation_context();
        context.current()
    }

    /// Returns how many entries are on the elicitation stack.
    ///
    /// Delegates to [`ElicitationContext::depth`].
    fn current_depth(&self) -> ElicitResult<usize> {
        let context = self.elicitation_context();
        context.depth()
    }

    /// Returns a copy of the whole elicitation stack.
    ///
    /// Delegates to [`ElicitationContext::stack`].
    fn elicitation_stack(&self) -> ElicitResult<Vec<TypeMetadata>> {
        let context = self.elicitation_context();
        context.stack()
    }
}
#[derive(Clone, Default)]
pub struct StyleContext {
styles: Arc<RwLock<HashMap<TypeId, Box<dyn std::any::Any + Send + Sync>>>>,
}
impl StyleContext {
#[tracing::instrument(skip(self, style), level = "debug", fields(type_id = ?TypeId::of::<T>()))]
pub fn set_style<T: 'static, S: ElicitationStyle>(&mut self, style: S) -> ElicitResult<()> {
let type_id = TypeId::of::<T>();
let mut styles = self.styles.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"StyleContext lock poisoned: {}",
e
)))
})?;
styles.insert(type_id, Box::new(style));
Ok(())
}
#[tracing::instrument(skip(self), level = "debug", fields(type_id = ?TypeId::of::<T>()))]
pub fn get_style<T: 'static, S: ElicitationStyle>(&self) -> ElicitResult<Option<S>> {
let type_id = TypeId::of::<T>();
let styles = self.styles.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"StyleContext lock poisoned: {}",
e
)))
})?;
Ok(styles
.get(&type_id)
.and_then(|boxed| boxed.downcast_ref::<S>())
.cloned())
}
}
#[derive(Clone, Default)]
pub struct ElicitationContext {
stack: Arc<RwLock<Vec<TypeMetadata>>>,
}
impl ElicitationContext {
pub fn push(&self, metadata: TypeMetadata) -> ElicitResult<()> {
let mut stack = self.stack.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
stack.push(metadata.clone());
tracing::debug!(
type_name = metadata.type_name,
depth = stack.len(),
"Entering elicitation"
);
Ok(())
}
pub fn pop(&self) -> ElicitResult<()> {
let mut stack = self.stack.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
if let Some(metadata) = stack.pop() {
tracing::debug!(
type_name = metadata.type_name,
depth = stack.len(),
"Exiting elicitation"
);
}
Ok(())
}
pub fn current(&self) -> ElicitResult<Option<TypeMetadata>> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.last().cloned())
}
pub fn depth(&self) -> ElicitResult<usize> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.len())
}
pub fn stack(&self) -> ElicitResult<Vec<TypeMetadata>> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.clone())
}
}