Skip to main content

wisp/components/
elicitation_form.rs

1use acp_utils::notifications::{ElicitationAction, ElicitationParams, ElicitationResponse};
2use acp_utils::{
3    ConstTitle, ElicitationSchema, EnumSchema, MultiSelectEnumSchema, PrimitiveSchema, SingleSelectEnumSchema,
4};
5use tokio::sync::oneshot;
6use tui::{Checkbox, MultiSelect, NumberField, RadioSelect, SelectOption, TextField};
7use tui::{Component, Event, Form, FormField, FormFieldKind, FormMessage, Frame, ViewContext};
8
/// Messages that [`ElicitationForm`] reports to its parent component.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ElicitationMessage {
    /// The user submitted or dismissed the form. The matching response was handed
    /// to the response channel if that channel was still pending.
    Responded,
}
12
13pub struct ElicitationForm {
14    pub form: Form,
15    pub(crate) response_tx: Option<oneshot::Sender<ElicitationResponse>>,
16}
17
18impl Component for ElicitationForm {
19    type Message = ElicitationMessage;
20
21    async fn on_event(&mut self, event: &Event) -> Option<Vec<Self::Message>> {
22        let outcome = self.form.on_event(event).await?;
23        if let Some(msg) = outcome.into_iter().next() {
24            match msg {
25                FormMessage::Close => {
26                    let _ = self.response_tx.take().map(|tx| tx.send(Self::decline()));
27                    return Some(vec![ElicitationMessage::Responded]);
28                }
29                FormMessage::Submit => {
30                    let response = self.confirm();
31                    let _ = self.response_tx.take().map(|tx| tx.send(response));
32                    return Some(vec![ElicitationMessage::Responded]);
33                }
34            }
35        }
36        Some(vec![])
37    }
38
39    fn render(&mut self, ctx: &ViewContext) -> Frame {
40        self.form.render(ctx)
41    }
42}
43
44impl ElicitationForm {
45    pub fn from_params(params: ElicitationParams, response_tx: oneshot::Sender<ElicitationResponse>) -> Self {
46        let fields = parse_schema(&params.schema);
47        Self { form: Form::new(params.message, fields), response_tx: Some(response_tx) }
48    }
49
50    pub fn confirm(&self) -> ElicitationResponse {
51        ElicitationResponse { action: ElicitationAction::Accept, content: Some(self.form.to_json()) }
52    }
53
54    pub fn decline() -> ElicitationResponse {
55        ElicitationResponse { action: ElicitationAction::Decline, content: None }
56    }
57}
58
59fn parse_schema(schema: &ElicitationSchema) -> Vec<FormField> {
60    let required = schema.required.as_deref().unwrap_or(&[]);
61    schema
62        .properties
63        .iter()
64        .map(|(name, prop)| {
65            let (title, description) = extract_metadata(prop);
66            FormField {
67                name: name.clone(),
68                label: title.unwrap_or_else(|| name.clone()),
69                description,
70                required: required.iter().any(|r| r == name),
71                kind: parse_field_kind(prop),
72            }
73        })
74        .collect()
75}
76
77fn parse_field_kind(prop: &PrimitiveSchema) -> FormFieldKind {
78    match prop {
79        PrimitiveSchema::Boolean(b) => FormFieldKind::Boolean(Checkbox::new(b.default.unwrap_or(false))),
80        PrimitiveSchema::Integer(_) => FormFieldKind::Number(NumberField::new(String::new(), true)),
81        PrimitiveSchema::Number(_) => FormFieldKind::Number(NumberField::new(String::new(), false)),
82        PrimitiveSchema::String(_) => FormFieldKind::Text(TextField::new(String::new())),
83        PrimitiveSchema::Enum(e) => parse_enum_field(e),
84    }
85}
86
87fn parse_enum_field(e: &EnumSchema) -> FormFieldKind {
88    match e {
89        EnumSchema::Single(s) => match s {
90            SingleSelectEnumSchema::Untitled(u) => {
91                let options = options_from_strings(&u.enum_);
92                let default_idx =
93                    u.default.as_ref().and_then(|d| options.iter().position(|o| o.value == *d)).unwrap_or(0);
94                FormFieldKind::SingleSelect(RadioSelect::new(options, default_idx))
95            }
96            SingleSelectEnumSchema::Titled(t) => {
97                let options = options_from_const_titles(&t.one_of);
98                let default_idx =
99                    t.default.as_ref().and_then(|d| options.iter().position(|o| o.value == *d)).unwrap_or(0);
100                FormFieldKind::SingleSelect(RadioSelect::new(options, default_idx))
101            }
102        },
103        EnumSchema::Multi(m) => match m {
104            MultiSelectEnumSchema::Untitled(u) => {
105                let options = options_from_strings(&u.items.enum_);
106                let defaults = u.default.as_deref().unwrap_or(&[]);
107                let selected: Vec<bool> = options.iter().map(|o| defaults.contains(&o.value)).collect();
108                FormFieldKind::MultiSelect(MultiSelect::new(options, selected))
109            }
110            MultiSelectEnumSchema::Titled(t) => {
111                let options = options_from_const_titles(&t.items.any_of);
112                let defaults = t.default.as_deref().unwrap_or(&[]);
113                let selected: Vec<bool> = options.iter().map(|o| defaults.contains(&o.value)).collect();
114                FormFieldKind::MultiSelect(MultiSelect::new(options, selected))
115            }
116        },
117        EnumSchema::Legacy(l) => {
118            let options = options_from_strings(&l.enum_);
119            FormFieldKind::SingleSelect(RadioSelect::new(options, 0))
120        }
121    }
122}
123
124fn extract_metadata(prop: &PrimitiveSchema) -> (Option<String>, Option<String>) {
125    match prop {
126        PrimitiveSchema::String(s) => {
127            (s.title.as_ref().map(ToString::to_string), s.description.as_ref().map(ToString::to_string))
128        }
129        PrimitiveSchema::Number(n) => {
130            (n.title.as_ref().map(ToString::to_string), n.description.as_ref().map(ToString::to_string))
131        }
132        PrimitiveSchema::Integer(i) => {
133            (i.title.as_ref().map(ToString::to_string), i.description.as_ref().map(ToString::to_string))
134        }
135        PrimitiveSchema::Boolean(b) => {
136            (b.title.as_ref().map(ToString::to_string), b.description.as_ref().map(ToString::to_string))
137        }
138        PrimitiveSchema::Enum(e) => extract_enum_metadata(e),
139    }
140}
141
142fn extract_enum_metadata(e: &EnumSchema) -> (Option<String>, Option<String>) {
143    match e {
144        EnumSchema::Single(s) => match s {
145            SingleSelectEnumSchema::Untitled(u) => {
146                (u.title.as_ref().map(ToString::to_string), u.description.as_ref().map(ToString::to_string))
147            }
148            SingleSelectEnumSchema::Titled(t) => {
149                (t.title.as_ref().map(ToString::to_string), t.description.as_ref().map(ToString::to_string))
150            }
151        },
152        EnumSchema::Multi(m) => match m {
153            MultiSelectEnumSchema::Untitled(u) => {
154                (u.title.as_ref().map(ToString::to_string), u.description.as_ref().map(ToString::to_string))
155            }
156            MultiSelectEnumSchema::Titled(t) => {
157                (t.title.as_ref().map(ToString::to_string), t.description.as_ref().map(ToString::to_string))
158            }
159        },
160        EnumSchema::Legacy(l) => {
161            (l.title.as_ref().map(ToString::to_string), l.description.as_ref().map(ToString::to_string))
162        }
163    }
164}
165
166fn options_from_strings(values: &[String]) -> Vec<SelectOption> {
167    values.iter().map(|s| SelectOption { value: s.clone(), title: s.clone(), description: None }).collect()
168}
169
170fn options_from_const_titles(items: &[ConstTitle]) -> Vec<SelectOption> {
171    items
172        .iter()
173        .map(|ct| SelectOption { value: ct.const_.clone(), title: ct.title.clone(), description: None })
174        .collect()
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Schema with one property of every supported kind: text, integer, number,
    /// boolean (defaulting to `true`), a single-select enum and a multi-select
    /// enum. `name` and `color` are required.
    fn test_schema() -> ElicitationSchema {
        let raw = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "title": "Your Name",
                    "description": "Enter your full name"
                },
                "age": {
                    "type": "integer",
                    "title": "Age",
                    "minimum": 0,
                    "maximum": 150
                },
                "rating": {
                    "type": "number",
                    "title": "Rating"
                },
                "approved": {
                    "type": "boolean",
                    "title": "Approved",
                    "default": true
                },
                "color": {
                    "type": "string",
                    "title": "Favorite Color",
                    "enum": ["red", "green", "blue"]
                },
                "tags": {
                    "type": "array",
                    "title": "Tags",
                    "items": {
                        "type": "string",
                        "enum": ["fast", "reliable", "cheap"]
                    }
                }
            },
            "required": ["name", "color"]
        });
        serde_json::from_value(raw).unwrap()
    }

    /// Returns the parsed field called `name`, panicking if it is absent.
    fn field<'a>(fields: &'a [FormField], name: &str) -> &'a FormField {
        fields.iter().find(|f| f.name == name).unwrap()
    }

    #[test]
    fn parse_schema_extracts_all_field_types() {
        let fields = parse_schema(&test_schema());
        assert_eq!(fields.len(), 6);

        let name = field(&fields, "name");
        assert_eq!(name.label, "Your Name");
        assert!(name.required);
        assert!(matches!(name.kind, FormFieldKind::Text(_)));

        let FormFieldKind::Number(age) = &field(&fields, "age").kind else {
            panic!("Expected Number (integer)");
        };
        assert!(age.integer_only);

        let FormFieldKind::Boolean(approved) = &field(&fields, "approved").kind else {
            panic!("Expected Boolean");
        };
        assert!(approved.checked);

        let color = field(&fields, "color");
        assert!(color.required);
        let FormFieldKind::SingleSelect(radio) = &color.kind else {
            panic!("Expected SingleSelect");
        };
        assert_eq!(radio.options.len(), 3);
        assert_eq!(radio.options[0].value, "red");

        let FormFieldKind::MultiSelect(tags) = &field(&fields, "tags").kind else {
            panic!("Expected MultiSelect");
        };
        assert_eq!(tags.options.len(), 3);
        assert!(!tags.selected.iter().any(|&checked| checked));
    }

    #[test]
    fn confirm_produces_correct_json() {
        let color = EnumSchema::builder(vec!["red".into(), "green".into()])
            .untitled()
            .with_default("green")
            .unwrap()
            .build();
        let schema = ElicitationSchema::builder()
            .optional_string("name")
            .optional_bool("approved", true)
            .optional_enum_schema("color", color)
            .build()
            .unwrap();
        let (tx, _rx) = oneshot::channel();
        let params = ElicitationParams { message: "Test".to_string(), schema };

        let response = ElicitationForm::from_params(params, tx).confirm();

        assert_eq!(response.action, ElicitationAction::Accept);
        let content = response.content.unwrap();
        assert_eq!(content["name"], "");
        assert_eq!(content["approved"], true);
        assert_eq!(content["color"], "green");
    }

    #[test]
    fn esc_returns_decline() {
        let declined = ElicitationForm::decline();
        assert_eq!(declined.action, ElicitationAction::Decline);
        assert!(declined.content.is_none());
    }

    #[test]
    fn one_of_string_produces_single_select() {
        let raw = serde_json::json!({
            "type": "object",
            "properties": {
                "size": {
                    "type": "string",
                    "oneOf": [
                        { "const": "s", "title": "Small" },
                        { "const": "m", "title": "Medium" },
                        { "const": "l", "title": "Large" }
                    ]
                }
            }
        });
        let schema: ElicitationSchema = serde_json::from_value(raw).unwrap();

        let fields = parse_schema(&schema);
        assert_eq!(fields.len(), 1);
        let FormFieldKind::SingleSelect(radio) = &fields[0].kind else {
            panic!("Expected SingleSelect");
        };
        assert_eq!(radio.options.len(), 3);
        assert_eq!(radio.options[0].title, "Small");
        assert_eq!(radio.options[0].value, "s");
    }

    #[test]
    fn empty_schema_produces_no_fields() {
        let fields = parse_schema(&ElicitationSchema::new(BTreeMap::new()));
        assert!(fields.is_empty());
    }
}