use super::ValidationError;
use crate::{ElicitCommunicator, ElicitResult, Elicitation, Prompt};
use elicitation_macros::instrumented_impl;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// Generates an unconstrained newtype wrapper around a float primitive.
///
/// `impl_float_default_wrapper!(f32, F32Default)` expands to:
/// - `pub struct F32Default(f32)` deriving `Serialize`, `Deserialize` and
///   `JsonSchema`, and passed to `rmcp::elicit_safe!`;
/// - `new` / `get` / `into_inner`, none of which perform any validation;
/// - a `F32DefaultStyle` type produced by `crate::default_style!`;
/// - `Prompt` and `Elicitation` impls.
///
/// The generated `elicit` sends a single prompt and parses the trimmed reply
/// with `str::parse`. It does not re-prompt: a reply that fails to parse is
/// returned as an `ElicitErrorKind::ParseError`. Float parsing accepts
/// spellings such as `"NaN"` and `"inf"`, so non-finite values can be produced.
macro_rules! impl_float_default_wrapper {
    ($primitive:ty, $wrapper:ident) => {
        #[doc = concat!("Default wrapper for ", stringify!($primitive), " (unconstrained).")]
        #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)]
        #[schemars(description = concat!(stringify!($primitive), " value"))]
        pub struct $wrapper(
            #[schemars(description = "Float value")]
            $primitive
        );
        rmcp::elicit_safe!($wrapper);
        impl $wrapper {
            /// Wraps `value` as-is; no finiteness or sign check is performed.
            pub fn new(value: $primitive) -> Self {
                Self(value)
            }
            /// Returns the wrapped value.
            pub fn get(&self) -> $primitive {
                self.0
            }
            /// Consumes the wrapper and returns the wrapped value.
            pub fn into_inner(self) -> $primitive {
                self.0
            }
        }
        paste::paste! {
            // `[<$wrapper Style>]` pastes to an identifier such as `F32DefaultStyle`.
            crate::default_style!($wrapper => [<$wrapper Style>]);
            impl Prompt for $wrapper {
                fn prompt() -> Option<&'static str> {
                    Some("Please enter a number:")
                }
            }
            impl Elicitation for $wrapper {
                type Style = [<$wrapper Style>];
                #[tracing::instrument(skip(communicator))]
                async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
                    // Cannot panic: `prompt()` above always returns `Some`.
                    let prompt = Self::prompt().unwrap();
                    tracing::debug!(concat!("Eliciting ", stringify!($wrapper), " with server-side send_prompt"));
                    let response = communicator.send_prompt(prompt).await?;
                    // Single attempt: a parse failure is surfaced as an error rather
                    // than triggering another prompt.
                    let value: $primitive = response.trim().parse().map_err(|e| {
                        crate::ElicitError::new(crate::ElicitErrorKind::ParseError(
                            format!("Failed to parse {}: {}", stringify!($primitive), e),
                        ))
                    })?;
                    Ok(Self::new(value))
                }
            }
        }
    };
}
/// An `f32` that is finite and strictly greater than zero.
///
/// [`F32Positive::new`] rejects NaN, ±infinity, zero (including `-0.0`) and
/// negative values.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F32Positive(f32);

#[cfg_attr(not(kani), instrumented_impl)]
impl F32Positive {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// - `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    ///   value's `Display` text (an empty string under Kani).
    /// - `ValidationError::FloatNotPositive` if `value <= 0.0`, carrying the value
    ///   widened to `f64`.
    pub fn new(value: f32) -> Result<Self, ValidationError> {
        #[cfg(kani)]
        {
            // Under Kani both predicates are unconstrained booleans, so every
            // outcome is explored; presumably this avoids symbolic float
            // reasoning — NOTE(review): confirm intent.
            let is_finite: bool = kani::any();
            let is_positive: bool = kani::any();
            match (is_finite, is_positive) {
                (false, _) => Err(ValidationError::NotFinite(String::new())),
                (true, true) => Ok(Self(value)),
                (true, false) => Err(ValidationError::FloatNotPositive(f64::from(value))),
            }
        }
        #[cfg(not(kani))]
        {
            // The `(false, _)` arm takes precedence, so NaN and ±infinity are
            // always reported as non-finite regardless of the sign test.
            match (value.is_finite(), value > 0.0) {
                (false, _) => Err(ValidationError::NotFinite(value.to_string())),
                (true, true) => Ok(Self(value)),
                (true, false) => Err(ValidationError::FloatNotPositive(f64::from(value))),
            }
        }
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f32 {
        let Self(inner) = *self;
        inner
    }

    /// Consumes the wrapper, yielding the wrapped value.
    pub fn into_inner(self) -> f32 {
        let Self(inner) = self;
        inner
    }
}
crate::default_style!(F32Positive => F32PositiveStyle);
impl Prompt for F32Positive {
fn prompt() -> Option<&'static str> {
Some("Please enter a positive number (> 0.0):")
}
}
impl Elicitation for F32Positive {
type Style = F32PositiveStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F32Positive"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F32Positive (positive f32 value)");
loop {
let value = f32::elicit(communicator).await?;
match Self::new(value) {
Ok(positive) => {
tracing::debug!(value, "Valid F32Positive constructed");
return Ok(positive);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F32Positive, re-prompting");
}
}
}
}
}
/// An `f32` that is finite and greater than or equal to zero.
///
/// Outside Kani builds, negative zero is accepted by [`F32NonNegative::new`]
/// but stored as `+0.0`, so the wrapped value never has its sign bit set.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F32NonNegative(f32);

#[cfg_attr(not(kani), instrumented_impl)]
impl F32NonNegative {
    /// Validates `value` and wraps it.
    ///
    /// `-0.0` passes the `>= 0.0` check because IEEE 754 treats it as equal to
    /// `+0.0`. It is normalised to `+0.0`; every other accepted value is stored
    /// unchanged.
    ///
    /// # Errors
    ///
    /// - `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    ///   value's `Display` text (an empty string under Kani).
    /// - `ValidationError::FloatNegative` if `value < 0.0`, carrying the value
    ///   widened to `f64`.
    pub fn new(value: f32) -> Result<Self, ValidationError> {
        #[cfg(kani)]
        {
            // Under Kani the predicates are unconstrained booleans so every
            // outcome is explored — NOTE(review): presumably to avoid symbolic
            // float reasoning; confirm intent.
            let is_finite: bool = kani::any();
            let is_non_negative: bool = kani::any();
            if !is_finite {
                Err(ValidationError::NotFinite(String::new()))
            } else if is_non_negative {
                Ok(Self(value))
            } else {
                Err(ValidationError::FloatNegative(value as f64))
            }
        }
        #[cfg(not(kani))]
        {
            if !value.is_finite() {
                Err(ValidationError::NotFinite(value.to_string()))
            } else if value >= 0.0 {
                // `-0.0 >= 0.0` holds, so without this step a "non-negative"
                // value could report `is_sign_negative() == true`,
                // `signum() == -1.0`, and make `1.0 / x` negative infinity.
                let normalized = if value == 0.0 { 0.0 } else { value };
                Ok(Self(normalized))
            } else {
                Err(ValidationError::FloatNegative(value as f64))
            }
        }
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f32 {
        self.0
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> f32 {
        self.0
    }
}
crate::default_style!(F32NonNegative => F32NonNegativeStyle);
impl Prompt for F32NonNegative {
fn prompt() -> Option<&'static str> {
Some("Please enter a non-negative number (>= 0.0):")
}
}
impl Elicitation for F32NonNegative {
type Style = F32NonNegativeStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F32NonNegative"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F32NonNegative (non-negative f32 value)");
loop {
let value = f32::elicit(communicator).await?;
match Self::new(value) {
Ok(non_negative) => {
tracing::debug!(value, "Valid F32NonNegative constructed");
return Ok(non_negative);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F32NonNegative, re-prompting");
}
}
}
}
}
/// An `f32` that is neither NaN nor ±infinity.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F32Finite(f32);

#[cfg_attr(not(kani), instrumented_impl)]
impl F32Finite {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    /// value's `Display` text (an empty string under Kani).
    pub fn new(value: f32) -> Result<Self, ValidationError> {
        #[cfg(kani)]
        {
            // Unconstrained boolean so Kani explores both outcomes —
            // NOTE(review): presumably to avoid symbolic float reasoning.
            let is_finite: bool = kani::any();
            is_finite
                .then_some(Self(value))
                .ok_or_else(|| ValidationError::NotFinite(String::new()))
        }
        #[cfg(not(kani))]
        {
            value
                .is_finite()
                .then_some(Self(value))
                .ok_or_else(|| ValidationError::NotFinite(value.to_string()))
        }
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f32 {
        let Self(inner) = *self;
        inner
    }

    /// Consumes the wrapper, yielding the wrapped value.
    pub fn into_inner(self) -> f32 {
        let Self(inner) = self;
        inner
    }
}
crate::default_style!(F32Finite => F32FiniteStyle);
impl Prompt for F32Finite {
fn prompt() -> Option<&'static str> {
Some("Please enter a finite number (not NaN or infinite):")
}
}
impl Elicitation for F32Finite {
type Style = F32FiniteStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F32Finite"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F32Finite (finite f32 value)");
loop {
let value = f32::elicit(communicator).await?;
match Self::new(value) {
Ok(finite) => {
tracing::debug!(value, "Valid F32Finite constructed");
return Ok(finite);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F32Finite, re-prompting");
}
}
}
}
}
/// An `f64` that is finite and strictly greater than zero.
///
/// [`F64Positive::new`] rejects NaN, ±infinity, zero (including `-0.0`) and
/// negative values.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F64Positive(f64);

#[cfg_attr(not(kani), instrumented_impl)]
impl F64Positive {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// - `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    ///   value's `Display` text (an empty string under Kani).
    /// - `ValidationError::FloatNotPositive` if `value <= 0.0`, carrying the value.
    pub fn new(value: f64) -> Result<Self, ValidationError> {
        #[cfg(kani)]
        {
            // Under Kani both predicates are unconstrained booleans, so every
            // outcome is explored; presumably this avoids symbolic float
            // reasoning — NOTE(review): confirm intent.
            let is_finite: bool = kani::any();
            let is_positive: bool = kani::any();
            match (is_finite, is_positive) {
                (false, _) => Err(ValidationError::NotFinite(String::new())),
                (true, true) => Ok(Self(value)),
                (true, false) => Err(ValidationError::FloatNotPositive(value)),
            }
        }
        #[cfg(not(kani))]
        {
            // The `(false, _)` arm takes precedence, so NaN and ±infinity are
            // always reported as non-finite regardless of the sign test.
            match (value.is_finite(), value > 0.0) {
                (false, _) => Err(ValidationError::NotFinite(value.to_string())),
                (true, true) => Ok(Self(value)),
                (true, false) => Err(ValidationError::FloatNotPositive(value)),
            }
        }
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f64 {
        let Self(inner) = *self;
        inner
    }

    /// Consumes the wrapper, yielding the wrapped value.
    pub fn into_inner(self) -> f64 {
        let Self(inner) = self;
        inner
    }
}
crate::default_style!(F64Positive => F64PositiveStyle);
impl Prompt for F64Positive {
fn prompt() -> Option<&'static str> {
Some("Please enter a positive number (> 0.0):")
}
}
impl Elicitation for F64Positive {
type Style = F64PositiveStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F64Positive"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F64Positive (positive f64 value)");
loop {
let value = f64::elicit(communicator).await?;
match Self::new(value) {
Ok(positive) => {
tracing::debug!(value, "Valid F64Positive constructed");
return Ok(positive);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F64Positive, re-prompting");
}
}
}
}
}
/// An `f64` that is finite and greater than or equal to zero.
///
/// Outside Kani builds, negative zero is accepted by [`F64NonNegative::new`]
/// but stored as `+0.0`, so the wrapped value never has its sign bit set.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F64NonNegative(f64);

#[cfg_attr(not(kani), instrumented_impl)]
impl F64NonNegative {
    /// Validates `value` and wraps it.
    ///
    /// `-0.0` passes the `>= 0.0` check because IEEE 754 treats it as equal to
    /// `+0.0`. It is normalised to `+0.0`; every other accepted value is stored
    /// unchanged.
    ///
    /// # Errors
    ///
    /// - `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    ///   value's `Display` text (an empty string under Kani).
    /// - `ValidationError::FloatNegative` if `value < 0.0`, carrying the value.
    pub fn new(value: f64) -> Result<Self, ValidationError> {
        #[cfg(kani)]
        {
            // Under Kani the predicates are unconstrained booleans so every
            // outcome is explored — NOTE(review): presumably to avoid symbolic
            // float reasoning; confirm intent.
            let is_finite: bool = kani::any();
            let is_non_negative: bool = kani::any();
            if !is_finite {
                Err(ValidationError::NotFinite(String::new()))
            } else if is_non_negative {
                Ok(Self(value))
            } else {
                Err(ValidationError::FloatNegative(value))
            }
        }
        #[cfg(not(kani))]
        {
            if !value.is_finite() {
                Err(ValidationError::NotFinite(value.to_string()))
            } else if value >= 0.0 {
                // `-0.0 >= 0.0` holds, so without this step a "non-negative"
                // value could report `is_sign_negative() == true`,
                // `signum() == -1.0`, and make `1.0 / x` negative infinity.
                let normalized = if value == 0.0 { 0.0 } else { value };
                Ok(Self(normalized))
            } else {
                Err(ValidationError::FloatNegative(value))
            }
        }
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f64 {
        self.0
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> f64 {
        self.0
    }
}
crate::default_style!(F64NonNegative => F64NonNegativeStyle);
impl Prompt for F64NonNegative {
fn prompt() -> Option<&'static str> {
Some("Please enter a non-negative number (>= 0.0):")
}
}
impl Elicitation for F64NonNegative {
type Style = F64NonNegativeStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F64NonNegative"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F64NonNegative (non-negative f64 value)");
loop {
let value = f64::elicit(communicator).await?;
match Self::new(value) {
Ok(non_negative) => {
tracing::debug!(value, "Valid F64NonNegative constructed");
return Ok(non_negative);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F64NonNegative, re-prompting");
}
}
}
}
}
/// An `f64` that is neither NaN nor ±infinity.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct F64Finite(f64);

#[cfg_attr(not(kani), instrumented_impl)]
impl F64Finite {
    /// Validates `value` and wraps it.
    ///
    /// NOTE(review): unlike the other wrappers in this file, there is no
    /// `#[cfg(kani)]` nondeterministic branch here, so Kani sees the real
    /// `is_finite` check. Confirm whether this is intentional.
    ///
    /// # Errors
    ///
    /// `ValidationError::NotFinite` if `value` is NaN or infinite, carrying the
    /// value's `Display` text.
    pub fn new(value: f64) -> Result<Self, ValidationError> {
        value
            .is_finite()
            .then_some(Self(value))
            .ok_or_else(|| ValidationError::NotFinite(value.to_string()))
    }

    /// Returns the wrapped value.
    pub fn get(&self) -> f64 {
        let Self(inner) = *self;
        inner
    }

    /// Consumes the wrapper, yielding the wrapped value.
    pub fn into_inner(self) -> f64 {
        let Self(inner) = self;
        inner
    }
}
crate::default_style!(F64Finite => F64FiniteStyle);
impl Prompt for F64Finite {
fn prompt() -> Option<&'static str> {
Some("Please enter a finite number (not NaN or infinite):")
}
}
impl Elicitation for F64Finite {
type Style = F64FiniteStyle;
#[tracing::instrument(skip(communicator), fields(type_name = "F64Finite"))]
async fn elicit<C: ElicitCommunicator>(communicator: &C) -> ElicitResult<Self> {
tracing::debug!("Eliciting F64Finite (finite f64 value)");
loop {
let value = f64::elicit(communicator).await?;
match Self::new(value) {
Ok(finite) => {
tracing::debug!(value, "Valid F64Finite constructed");
return Ok(finite);
}
Err(e) => {
tracing::warn!(value, error = %e, "Invalid F64Finite, re-prompting");
}
}
}
}
}
/// Unit tests for [`F32Positive`]: the accepted range, every rejection path,
/// and which `ValidationError` variant each rejection produces.
#[cfg(test)]
mod f32_positive_tests {
    use super::*;

    #[test]
    fn f32_positive_new_valid() {
        let result = F32Positive::new(1.5);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 1.5);
    }

    #[test]
    fn f32_positive_new_smallest_subnormal_valid() {
        // The smallest positive subnormal is still strictly greater than zero.
        let tiny = f32::from_bits(1);
        assert_eq!(F32Positive::new(tiny).unwrap().get(), tiny);
    }

    #[test]
    fn f32_positive_new_zero_invalid() {
        let result = F32Positive::new(0.0);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(_))));
    }

    #[test]
    fn f32_positive_new_negative_zero_invalid() {
        // -0.0 compares equal to 0.0, so it must not count as positive.
        let result = F32Positive::new(-0.0);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(_))));
    }

    #[test]
    fn f32_positive_new_negative_invalid() {
        let result = F32Positive::new(-1.5);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(v)) if v == -1.5));
    }

    #[test]
    fn f32_positive_new_nan_invalid() {
        let result = F32Positive::new(f32::NAN);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_positive_new_infinity_invalid() {
        let result = F32Positive::new(f32::INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_positive_new_neg_infinity_invalid() {
        // Non-finiteness is checked before sign, so -inf is NotFinite.
        let result = F32Positive::new(f32::NEG_INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_positive_into_inner() {
        let positive = F32Positive::new(42.5).unwrap();
        let value: f32 = positive.into_inner();
        assert_eq!(value, 42.5);
    }
}
/// Unit tests for [`F32NonNegative`]: the accepted range (including both
/// zeros), every rejection path, and the `ValidationError` variant produced.
#[cfg(test)]
mod f32_nonnegative_tests {
    use super::*;

    #[test]
    fn f32_nonnegative_new_valid_positive() {
        let result = F32NonNegative::new(1.5);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 1.5);
    }

    #[test]
    fn f32_nonnegative_new_valid_zero() {
        let result = F32NonNegative::new(0.0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 0.0);
    }

    #[test]
    fn f32_nonnegative_new_negative_zero_accepted() {
        // -0.0 == 0.0 under IEEE 754, so it satisfies `>= 0.0`.
        let result = F32NonNegative::new(-0.0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 0.0);
    }

    #[test]
    fn f32_nonnegative_new_negative_invalid() {
        let result = F32NonNegative::new(-1.5);
        assert!(matches!(result, Err(ValidationError::FloatNegative(v)) if v == -1.5));
    }

    #[test]
    fn f32_nonnegative_new_negative_subnormal_invalid() {
        let result = F32NonNegative::new(-f32::from_bits(1));
        assert!(matches!(result, Err(ValidationError::FloatNegative(_))));
    }

    #[test]
    fn f32_nonnegative_new_nan_invalid() {
        let result = F32NonNegative::new(f32::NAN);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_nonnegative_new_infinity_invalid() {
        let result = F32NonNegative::new(f32::INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_nonnegative_new_neg_infinity_invalid() {
        let result = F32NonNegative::new(f32::NEG_INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f32_nonnegative_into_inner() {
        let non_neg = F32NonNegative::new(10.5).unwrap();
        let value: f32 = non_neg.into_inner();
        assert_eq!(value, 10.5);
    }
}
/// Unit tests for [`F32Finite`]: finite values round-trip and every
/// non-finite value is rejected.
#[cfg(test)]
mod f32_finite_tests {
    use super::*;

    /// Asserts that `input` is accepted and that `get` returns it unchanged.
    fn assert_accepts(input: f32) {
        let wrapped = F32Finite::new(input).expect("finite input must be accepted");
        assert_eq!(wrapped.get(), input);
    }

    #[test]
    fn f32_finite_new_valid_positive() {
        assert_accepts(1.5);
    }

    #[test]
    fn f32_finite_new_valid_negative() {
        assert_accepts(-1.5);
    }

    #[test]
    fn f32_finite_new_valid_zero() {
        assert_accepts(0.0);
    }

    #[test]
    fn f32_finite_new_nan_invalid() {
        assert!(F32Finite::new(f32::NAN).is_err());
    }

    #[test]
    fn f32_finite_new_infinity_invalid() {
        assert!(F32Finite::new(f32::INFINITY).is_err());
    }

    #[test]
    fn f32_finite_new_neg_infinity_invalid() {
        assert!(F32Finite::new(f32::NEG_INFINITY).is_err());
    }

    #[test]
    fn f32_finite_into_inner() {
        let unwrapped: f32 = F32Finite::new(42.5).unwrap().into_inner();
        assert_eq!(unwrapped, 42.5);
    }
}
/// Unit tests for [`F64Positive`]: the accepted range, every rejection path,
/// and which `ValidationError` variant each rejection produces.
#[cfg(test)]
mod f64_positive_tests {
    use super::*;

    #[test]
    fn f64_positive_new_valid() {
        let result = F64Positive::new(1.5);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 1.5);
    }

    #[test]
    fn f64_positive_new_smallest_subnormal_valid() {
        // The smallest positive subnormal is still strictly greater than zero.
        let tiny = f64::from_bits(1);
        assert_eq!(F64Positive::new(tiny).unwrap().get(), tiny);
    }

    #[test]
    fn f64_positive_new_zero_invalid() {
        let result = F64Positive::new(0.0);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(_))));
    }

    #[test]
    fn f64_positive_new_negative_zero_invalid() {
        // -0.0 compares equal to 0.0, so it must not count as positive.
        let result = F64Positive::new(-0.0);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(_))));
    }

    #[test]
    fn f64_positive_new_negative_invalid() {
        let result = F64Positive::new(-1.5);
        assert!(matches!(result, Err(ValidationError::FloatNotPositive(v)) if v == -1.5));
    }

    #[test]
    fn f64_positive_new_nan_invalid() {
        let result = F64Positive::new(f64::NAN);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_positive_new_infinity_invalid() {
        let result = F64Positive::new(f64::INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_positive_new_neg_infinity_invalid() {
        // Non-finiteness is checked before sign, so -inf is NotFinite.
        let result = F64Positive::new(f64::NEG_INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_positive_into_inner() {
        let positive = F64Positive::new(42.5).unwrap();
        let value: f64 = positive.into_inner();
        assert_eq!(value, 42.5);
    }
}
/// Unit tests for [`F64NonNegative`]: the accepted range (including both
/// zeros), every rejection path, and the `ValidationError` variant produced.
#[cfg(test)]
mod f64_nonnegative_tests {
    use super::*;

    #[test]
    fn f64_nonnegative_new_valid_positive() {
        let result = F64NonNegative::new(1.5);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 1.5);
    }

    #[test]
    fn f64_nonnegative_new_valid_zero() {
        let result = F64NonNegative::new(0.0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 0.0);
    }

    #[test]
    fn f64_nonnegative_new_negative_zero_accepted() {
        // -0.0 == 0.0 under IEEE 754, so it satisfies `>= 0.0`.
        let result = F64NonNegative::new(-0.0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().get(), 0.0);
    }

    #[test]
    fn f64_nonnegative_new_negative_invalid() {
        let result = F64NonNegative::new(-1.5);
        assert!(matches!(result, Err(ValidationError::FloatNegative(v)) if v == -1.5));
    }

    #[test]
    fn f64_nonnegative_new_negative_subnormal_invalid() {
        let result = F64NonNegative::new(-f64::from_bits(1));
        assert!(matches!(result, Err(ValidationError::FloatNegative(_))));
    }

    #[test]
    fn f64_nonnegative_new_nan_invalid() {
        let result = F64NonNegative::new(f64::NAN);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_nonnegative_new_infinity_invalid() {
        let result = F64NonNegative::new(f64::INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_nonnegative_new_neg_infinity_invalid() {
        let result = F64NonNegative::new(f64::NEG_INFINITY);
        assert!(matches!(result, Err(ValidationError::NotFinite(_))));
    }

    #[test]
    fn f64_nonnegative_into_inner() {
        let non_neg = F64NonNegative::new(10.5).unwrap();
        let value: f64 = non_neg.into_inner();
        assert_eq!(value, 10.5);
    }
}
/// Unit tests for [`F64Finite`]: finite values round-trip and every
/// non-finite value is rejected.
#[cfg(test)]
mod f64_finite_tests {
    use super::*;

    /// Asserts that `input` is accepted and that `get` returns it unchanged.
    fn assert_accepts(input: f64) {
        let wrapped = F64Finite::new(input).expect("finite input must be accepted");
        assert_eq!(wrapped.get(), input);
    }

    #[test]
    fn f64_finite_new_valid_positive() {
        assert_accepts(1.5);
    }

    #[test]
    fn f64_finite_new_valid_negative() {
        assert_accepts(-1.5);
    }

    #[test]
    fn f64_finite_new_valid_zero() {
        assert_accepts(0.0);
    }

    #[test]
    fn f64_finite_new_nan_invalid() {
        assert!(F64Finite::new(f64::NAN).is_err());
    }

    #[test]
    fn f64_finite_new_infinity_invalid() {
        assert!(F64Finite::new(f64::INFINITY).is_err());
    }

    #[test]
    fn f64_finite_new_neg_infinity_invalid() {
        assert!(F64Finite::new(f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn f64_finite_into_inner() {
        let unwrapped: f64 = F64Finite::new(42.5).unwrap().into_inner();
        assert_eq!(unwrapped, 42.5);
    }
}
// Unconstrained `f32` / `f64` wrappers generated by `impl_float_default_wrapper!`.
// They perform no finiteness or sign checks; use the validated wrappers above
// (`F32Finite`, `F64Positive`, ...) when a constraint is required.
impl_float_default_wrapper!(f32, F32Default);
impl_float_default_wrapper!(f64, F64Default);