use crate::{ElicitError, ElicitErrorKind, ElicitResult, Elicitation, StyleMarker, TypeMetadata};
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Channel through which elicitation talks to the outside world.
///
/// Implementors send prompts, forward MCP tool calls, and expose the
/// [`StyleContext`] and [`ElicitationContext`] that the default methods below
/// read from.
pub trait ElicitCommunicator: Clone + Send + Sync {
    /// Sends `prompt` and resolves to the response text.
    fn send_prompt(
        &self,
        prompt: &str,
    ) -> impl std::future::Future<Output = ElicitResult<String>> + Send;

    /// Forwards a tool invocation and resolves to the tool's result.
    fn call_tool(
        &self,
        params: rmcp::model::CallToolRequestParams,
    ) -> impl std::future::Future<
        Output = Result<rmcp::model::CallToolResult, rmcp::service::ServiceError>,
    > + Send;

    /// Style registry consulted by `style_or_default` and `style_or_elicit`.
    fn style_context(&self) -> &StyleContext;

    /// Returns a communicator configured with `style` for type `T`.
    ///
    /// NOTE(review): the exact semantics (fresh vs. shared style map) are up to
    /// the implementor; `StyleContext` clones share their underlying map.
    fn with_style<T: 'static, S: StyleMarker + crate::style::ElicitationStyle + 'static>(
        &self,
        style: S,
    ) -> Self;

    /// Returns the style registered for `T`, or `T::Style::default()` if none is.
    ///
    /// # Errors
    /// Propagates the error from `StyleContext::get_style` (poisoned lock).
    fn style_or_default<T: Elicitation + 'static>(&self) -> ElicitResult<T::Style>
    where
        T::Style: StyleMarker,
    {
        let registered = self.style_context().get_style::<T, T::Style>()?;
        Ok(registered.unwrap_or_default())
    }

    /// Returns the style registered for `T`, or elicits one when none is registered.
    ///
    /// An elicited style is returned to the caller only; it is not written back
    /// into the style context.
    ///
    /// # Errors
    /// Propagates lock errors from the style lookup and any error from
    /// eliciting `T::Style`.
    fn style_or_elicit<T: Elicitation + 'static>(
        &self,
    ) -> impl std::future::Future<Output = ElicitResult<T::Style>> + Send
    where
        T::Style: StyleMarker,
    {
        async move {
            // Fast path: a style was already registered for `T`.
            if let Some(style) = self.style_context().get_style::<T, T::Style>()? {
                return Ok(style);
            }
            // Slow path: ask through this communicator.
            T::Style::elicit(self).await
        }
    }

    /// Tracker of the types currently being elicited.
    fn elicitation_context(&self) -> &ElicitationContext;

    /// Innermost type currently being elicited, if any.
    fn current_type(&self) -> ElicitResult<Option<TypeMetadata>> {
        let ctx = self.elicitation_context();
        ctx.current()
    }

    /// Number of elicitations currently in progress (nesting depth).
    fn current_depth(&self) -> ElicitResult<usize> {
        let ctx = self.elicitation_context();
        ctx.depth()
    }

    /// Snapshot of all types currently being elicited, outermost first.
    fn elicitation_stack(&self) -> ElicitResult<Vec<TypeMetadata>> {
        let ctx = self.elicitation_context();
        ctx.stack()
    }
}
/// Object-safe handle for a type-erased style stored in `StyleContext`.
///
/// It exposes the stored value two ways: as `Any`, so the concrete style type
/// can be recovered by downcasting, and as an `ElicitationStyle` trait object.
trait StyleEntry: Send + Sync {
    /// Borrows the entry as `Any` for downcasting to its concrete type.
    fn as_any(&self) -> &dyn std::any::Any;
    /// Borrows the entry as a type-erased elicitation style.
    fn as_style(&self) -> &dyn crate::style::ElicitationStyle;
}

// Any `'static` elicitation style can be stored as an entry.
impl<S> StyleEntry for S
where
    S: crate::style::ElicitationStyle + 'static,
{
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    fn as_style(&self) -> &dyn crate::style::ElicitationStyle {
        self as &dyn crate::style::ElicitationStyle
    }
}
#[derive(Clone, Default)]
pub struct StyleContext {
styles: Arc<RwLock<HashMap<TypeId, Box<dyn StyleEntry>>>>,
}
impl StyleContext {
#[tracing::instrument(skip(self, style), level = "debug", fields(type_id = ?TypeId::of::<T>()))]
pub fn set_style<T: 'static, S: StyleMarker + crate::style::ElicitationStyle + 'static>(
&mut self,
style: S,
) -> ElicitResult<()> {
let type_id = TypeId::of::<T>();
let mut styles = self.styles.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"StyleContext lock poisoned: {}",
e
)))
})?;
styles.insert(type_id, Box::new(style));
Ok(())
}
#[tracing::instrument(skip(self), level = "debug", fields(type_id = ?TypeId::of::<T>()))]
pub fn get_style<T: 'static, S: StyleMarker>(&self) -> ElicitResult<Option<S>> {
let type_id = TypeId::of::<T>();
let styles = self.styles.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"StyleContext lock poisoned: {}",
e
)))
})?;
Ok(styles
.get(&type_id)
.and_then(|entry| entry.as_any().downcast_ref::<S>())
.cloned())
}
#[tracing::instrument(skip(self), level = "debug", fields(type_id = ?TypeId::of::<T>()))]
pub fn prompt_for_type<T: 'static>(
&self,
field_name: &str,
field_type: &str,
context: &crate::style::PromptContext,
) -> ElicitResult<Option<String>> {
let type_id = TypeId::of::<T>();
let styles = self.styles.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"StyleContext lock poisoned: {}",
e
)))
})?;
Ok(styles.get(&type_id).map(|entry| {
entry
.as_style()
.prompt_for_field(field_name, field_type, context)
}))
}
}
#[derive(Clone, Default)]
pub struct ElicitationContext {
stack: Arc<RwLock<Vec<TypeMetadata>>>,
}
impl ElicitationContext {
pub fn push(&self, metadata: TypeMetadata) -> ElicitResult<()> {
let mut stack = self.stack.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
stack.push(metadata.clone());
tracing::debug!(
type_name = metadata.type_name,
depth = stack.len(),
"Entering elicitation"
);
Ok(())
}
pub fn pop(&self) -> ElicitResult<()> {
let mut stack = self.stack.write().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
if let Some(metadata) = stack.pop() {
tracing::debug!(
type_name = metadata.type_name,
depth = stack.len(),
"Exiting elicitation"
);
}
Ok(())
}
pub fn current(&self) -> ElicitResult<Option<TypeMetadata>> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.last().cloned())
}
pub fn depth(&self) -> ElicitResult<usize> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.len())
}
pub fn stack(&self) -> ElicitResult<Vec<TypeMetadata>> {
let stack = self.stack.read().map_err(|e| {
ElicitError::new(ElicitErrorKind::ParseError(format!(
"ElicitationContext lock poisoned: {}",
e
)))
})?;
Ok(stack.clone())
}
}