use crate::{
ElicitCommunicator, ElicitError, ElicitErrorKind, ElicitResult, Elicitation, Generator, Prompt,
Select, mcp,
};
use std::time::Duration;
// Style types for the two `Elicitation` impls in this file: `DurationStyle` and
// `DurationGenerationModeStyle` are used as `type Style` below. Presumably
// `default_style!` declares these types with a default implementation — confirm
// against the macro definition.
crate::default_style!(Duration => DurationStyle);
crate::default_style!(DurationGenerationMode => DurationGenerationModeStyle);
/// Strategy used by [`DurationGenerator`] to build a [`Duration`].
///
/// Every variant other than `Zero` carries the magnitude that `generate` passes to
/// the matching `Duration::from_*` constructor. Serde support is only derived when
/// the `serde` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DurationGenerationMode {
    /// Produces `Duration::ZERO`.
    Zero,
    /// Whole seconds, passed to `Duration::from_secs`.
    FromSecs(u64),
    /// Milliseconds, passed to `Duration::from_millis`.
    FromMillis(u64),
    /// Microseconds, passed to `Duration::from_micros`.
    FromMicros(u64),
    /// Nanoseconds, passed to `Duration::from_nanos`.
    FromNanos(u64),
}
impl Select for DurationGenerationMode {
    /// Every selectable mode, index-aligned with [`Select::labels`].
    ///
    /// The unit-based variants carry a placeholder `0`. This file's `Elicitation`
    /// impl asks for the real magnitude after a mode has been picked.
    fn options() -> Vec<Self> {
        Vec::from([
            Self::Zero,
            Self::FromSecs(0),
            Self::FromMillis(0),
            Self::FromMicros(0),
            Self::FromNanos(0),
        ])
    }

    /// Menu labels shown to the user, index-aligned with [`Select::options`].
    fn labels() -> Vec<String> {
        [
            "Zero",
            "From Seconds",
            "From Milliseconds",
            "From Microseconds",
            "From Nanoseconds",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect()
    }

    /// Maps a menu label back to its placeholder-valued mode.
    ///
    /// Returns `None` for any label not produced by [`Select::labels`]. The lookup
    /// pairs `labels()` with `options()` so the two lists remain the single source of truth.
    fn from_label(label: &str) -> Option<Self> {
        Self::labels()
            .into_iter()
            .zip(Self::options())
            .find_map(|(candidate, mode)| (candidate == label).then_some(mode))
    }
}
impl Prompt for DurationGenerationMode {
    /// The question shown above the mode menu when a [`DurationGenerationMode`]
    /// is elicited.
    fn prompt() -> Option<&'static str> {
        let question: &'static str = "How should durations be created?";
        Some(question)
    }
}
impl Elicitation for DurationGenerationMode {
    type Style = DurationGenerationModeStyle;

    /// Asks the user to pick a mode through the `elicit_select` tool. For the
    /// unit-based modes it then elicits the `u64` magnitude as a follow-up.
    ///
    /// # Errors
    ///
    /// Propagates failures from:
    /// - the tool call,
    /// - value extraction and string parsing,
    /// - the follow-up `u64` elicitation.
    ///
    /// Returns a `ParseError` when the chosen label is not one of [`Select::labels`].
    async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
        let prompt = Self::prompt().unwrap_or("Select an option:");
        let labels = Self::labels();
        let arguments = mcp::select_params(prompt, &labels);
        let request = rmcp::model::CallToolRequestParams::new(mcp::tool_names::elicit_select())
            .with_arguments(arguments);
        let response = communicator.call_tool(request).await?;
        let chosen = mcp::parse_string(mcp::extract_value(response)?)?;

        let picked = Self::from_label(&chosen).ok_or_else(|| {
            ElicitError::new(ElicitErrorKind::ParseError(
                "Invalid Duration generation mode".to_string(),
            ))
        })?;

        // `from_label` returns a placeholder payload, so it only identifies the unit.
        // The real magnitude is asked for wherever one is needed.
        let mode = match picked {
            Self::Zero => Self::Zero,
            Self::FromSecs(_) => Self::FromSecs(u64::elicit(communicator).await?),
            Self::FromMillis(_) => Self::FromMillis(u64::elicit(communicator).await?),
            Self::FromMicros(_) => Self::FromMicros(u64::elicit(communicator).await?),
            Self::FromNanos(_) => Self::FromNanos(u64::elicit(communicator).await?),
        };
        Ok(mode)
    }

    fn kani_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::kani_multi_variant_enum;
        kani_multi_variant_enum("DurationGenerationMode", "Zero")
    }

    fn verus_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::verus_multi_variant_enum;
        verus_multi_variant_enum("DurationGenerationMode")
    }

    fn creusot_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::creusot_multi_variant_enum;
        creusot_multi_variant_enum("DurationGenerationMode")
    }
}
/// Builds [`Duration`] values from a stored [`DurationGenerationMode`].
///
/// `generate` reads only `mode`, so repeated calls return the same duration.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DurationGenerator {
    // The strategy used by `generate`. It is set via `new` and read back through `mode()`.
    mode: DurationGenerationMode,
}
impl DurationGenerator {
pub fn new(mode: DurationGenerationMode) -> Self {
Self { mode }
}
pub fn mode(&self) -> DurationGenerationMode {
self.mode
}
}
impl Generator for DurationGenerator {
type Target = Duration;
fn generate(&self) -> Self::Target {
match self.mode {
DurationGenerationMode::Zero => Duration::ZERO,
DurationGenerationMode::FromSecs(secs) => Duration::from_secs(secs),
DurationGenerationMode::FromMillis(millis) => Duration::from_millis(millis),
DurationGenerationMode::FromMicros(micros) => Duration::from_micros(micros),
DurationGenerationMode::FromNanos(nanos) => Duration::from_nanos(nanos),
}
}
}
impl Prompt for Duration {
    /// Prompt text associated with [`Duration`].
    ///
    /// `Duration::elicit` in this file does not read it directly; presumably it is
    /// consumed by generic callers of [`Prompt`].
    fn prompt() -> Option<&'static str> {
        const QUESTION: &str = "Choose how to create the duration:";
        Some(QUESTION)
    }
}
impl Elicitation for Duration {
    type Style = DurationStyle;

    /// Elicits a [`Duration`] in two steps:
    /// 1. A [`DurationGenerationMode`] is elicited (see its `Elicitation` impl in this file).
    /// 2. The mode is turned into a value by a [`DurationGenerator`].
    ///
    /// # Errors
    ///
    /// Propagates any error from eliciting the [`DurationGenerationMode`].
    #[tracing::instrument(skip(communicator))]
    async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
        tracing::debug!("Eliciting Duration");
        let mode = DurationGenerationMode::elicit(communicator).await?;
        let generated = DurationGenerator::new(mode).generate();
        tracing::debug!(duration = ?generated, mode = ?mode, "Generated Duration");
        Ok(generated)
    }

    fn kani_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::kani_duration;
        kani_duration()
    }

    fn verus_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::verus_duration;
        verus_duration()
    }

    fn creusot_proof() -> proc_macro2::TokenStream {
        use crate::verification::proof_helpers::creusot_duration;
        creusot_duration()
    }
}