use crate::{ElicitCommunicator, ElicitError, ElicitErrorKind, ElicitResult, mcp};
use std::fmt::Display;
/// A set of candidate values from which the user is asked to pick exactly one.
///
/// Each item is shown to the user through its [`Display`] rendering, and the
/// label the user picks is mapped back to the item that produced it.
#[derive(Debug, Clone)]
pub struct ChoiceSet<T> {
    /// Candidate values, in the order their labels are sent to the select tool.
    items: Vec<T>,
    /// Custom prompt text; `elicit` falls back to `"Choose an option:"` when unset.
    prompt: Option<String>,
}

impl<T> ChoiceSet<T>
where
    T: Display + Clone + PartialEq + Send + Sync + 'static,
{
    /// Creates a choice set over `items` with no custom prompt.
    pub fn new(items: Vec<T>) -> Self {
        Self {
            items,
            prompt: None,
        }
    }

    /// Sets the prompt used when eliciting, replacing any previously set prompt.
    #[must_use]
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = Some(prompt.into());
        self
    }

    /// Returns the candidate items, in order.
    pub fn items(&self) -> &[T] {
        &self.items
    }

    /// Creates a choice set (with no custom prompt) containing only the items
    /// for which `filter` returns `true`, preserving their relative order.
    pub fn filtered<F>(items: Vec<T>, filter: F) -> Self
    where
        F: Fn(&T) -> bool,
    {
        Self::new(items).with_filter(filter)
    }

    /// Keeps only the items for which `filter` returns `true`, preserving
    /// their relative order and the current prompt.
    #[must_use]
    pub fn with_filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&T) -> bool,
    {
        // Filter in place: no new Vec is allocated and the prompt is untouched.
        self.items.retain(filter);
        self
    }

    /// Asks the user to choose one item and returns it.
    ///
    /// Every item is rendered with [`Display`]. The resulting labels are passed,
    /// together with the prompt (or `"Choose an option:"` if none was set), to
    /// the tool named by `mcp::tool_names::elicit_select()`. The label returned
    /// by the tool is matched back to its item. If several items render to the
    /// same label, the first of them is returned.
    ///
    /// # Errors
    ///
    /// - `ElicitErrorKind::Validation` if the set is empty; no tool call is made.
    /// - `ElicitErrorKind::InvalidSelection` if the returned label matches no item.
    /// - Any error from the tool call, or from extracting the result value and
    ///   parsing it as a string, is propagated unchanged.
    #[tracing::instrument(skip(communicator, self), fields(item_count = self.items.len()))]
    pub async fn elicit<C: ElicitCommunicator>(self, communicator: &C) -> ElicitResult<T> {
        if self.items.is_empty() {
            return Err(ElicitError::new(ElicitErrorKind::Validation(
                "Cannot elicit from empty choice set".to_string(),
            )));
        }

        let labels: Vec<String> = self.items.iter().map(ToString::to_string).collect();
        let prompt_text = self.prompt.as_deref().unwrap_or("Choose an option:");
        let params = mcp::select_params(prompt_text, &labels);

        let result = communicator
            .call_tool(rmcp::model::CallToolRequestParams {
                meta: None,
                name: mcp::tool_names::elicit_select().into(),
                arguments: Some(params),
                task: None,
            })
            .await?;
        let value = mcp::extract_value(result)?;
        let selected_label = mcp::parse_string(value)?;

        // `position` yields the first match, so duplicate labels resolve to the
        // earliest item with that rendering.
        match labels.iter().position(|label| *label == selected_label) {
            Some(index) => {
                tracing::debug!(selected = %selected_label, "Item selected");
                // `self` is owned, so move the chosen item out instead of cloning it.
                // `index < labels.len() == items.len()`, so `swap_remove` cannot panic.
                let mut items = self.items;
                Ok(items.swap_remove(index))
            }
            None => Err(ElicitError::new(ElicitErrorKind::InvalidSelection(
                selected_label,
            ))),
        }
    }
}