Skip to main content

wisp/components/
input_prompt.rs

1use std::ops::Range;
2use tui::{Line, ViewContext, soft_wrap_text_position};
3
/// Columns available for input text once the two-column prompt gutter is removed.
///
/// Never returns less than 1, so callers can always wrap to a non-zero width.
pub fn prompt_content_width(terminal_width: usize) -> usize {
    match terminal_width.checked_sub(2) {
        Some(remaining) if remaining > 0 => remaining,
        _ => 1,
    }
}
7
/// Column at which input text begins on every prompt row: just past the `"> "` gutter.
///
/// The terminal width is currently irrelevant; the parameter is kept for API stability.
pub fn prompt_text_start_col(_terminal_width: usize) -> usize {
    "> ".len()
}
11
/// Input-line state to lay out: the text, the cursor, and an optional highlighted span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputPrompt<'a> {
    /// The prompt text to display.
    pub input: &'a str,
    /// Cursor position as a byte offset into `input`. During layout, offsets past the end are
    /// clamped to `input.len()`, and offsets inside a multi-byte character are moved back to
    /// the start of that character.
    pub cursor_index: usize,
    /// Byte range of `input` to draw in the theme's warning color. It is ignored when empty,
    /// out of bounds, or not on UTF-8 character boundaries.
    pub highlight_range: Option<Range<usize>>,
}
17
18pub struct InputPromptLayout {
19    pub lines: Vec<Line>,
20    /// Cursor row within `lines` (0-based).
21    pub cursor_row: usize,
22    /// Cursor column on that row (0-based).
23    pub cursor_col: u16,
24}
25
26impl InputPrompt<'_> {
27    pub fn layout(&self, context: &ViewContext) -> InputPromptLayout {
28        let width = usize::from(context.size.width);
29        let cursor_index = clamp_to_char_boundary(self.input, self.cursor_index);
30        let styled_input = style_input(self.input, context, self.highlight_range.as_ref());
31
32        let content_width = prompt_content_width(width);
33        let content_width_u16 = u16::try_from(content_width).unwrap_or(u16::MAX);
34        let wrapped_chunks = styled_input.soft_wrap(content_width_u16);
35
36        let (cursor_content_row, cursor_content_col) = soft_wrap_text_position(self.input, cursor_index, content_width);
37
38        let content_rows = wrapped_chunks.len().max(cursor_content_row + 1);
39
40        let mut lines = Vec::with_capacity(content_rows + 2);
41        lines.push(Line::styled("─".repeat(width), context.theme.muted()));
42
43        for row in 0..content_rows {
44            let chunk = wrapped_chunks.get(row).cloned().unwrap_or_default();
45            let mut middle = Line::default();
46            if row == 0 {
47                middle.push_styled("> ", context.theme.primary());
48            } else {
49                middle.push_styled("  ", context.theme.muted());
50            }
51            middle.append_line(&chunk);
52            lines.push(middle);
53        }
54
55        lines.push(Line::styled("─".repeat(width), context.theme.muted()));
56
57        InputPromptLayout {
58            lines,
59            cursor_row: 1 + cursor_content_row,
60            cursor_col: u16::try_from(prompt_text_start_col(width) + cursor_content_col).unwrap_or(u16::MAX),
61        }
62    }
63}
64
65impl InputPrompt<'_> {
66    #[cfg(test)]
67    pub fn render(&self, context: &ViewContext) -> Vec<Line> {
68        self.layout(context).lines
69    }
70}
71
/// Returns the largest UTF-8 character boundary of `input` that is `<= index`.
///
/// Offsets at or past the end map to `input.len()`, which is always a boundary.
fn clamp_to_char_boundary(input: &str, index: usize) -> usize {
    if index >= input.len() {
        return input.len();
    }
    // Byte 0 is always a boundary, so the search always succeeds; `unwrap_or` is just a backstop.
    (0..=index)
        .rev()
        .find(|&candidate| input.is_char_boundary(candidate))
        .unwrap_or(0)
}
79
80fn style_input(input: &str, context: &ViewContext, highlight: Option<&Range<usize>>) -> Line {
81    let highlight = highlight
82        .filter(|r| r.start < r.end && r.end <= input.len())
83        .filter(|r| input.is_char_boundary(r.start) && input.is_char_boundary(r.end));
84
85    if highlight.is_none() && !input.contains('@') {
86        return Line::styled(input, context.theme.text_primary());
87    }
88
89    let mentions = mention_ranges(input);
90    let base = context.theme.text_primary();
91    let info = context.theme.info();
92    let warning = context.theme.warning();
93
94    let color_at = |byte: usize| -> tui::Color {
95        if let Some(r) = &highlight
96            && r.contains(&byte)
97        {
98            return warning;
99        }
100        if mentions.iter().any(|m| m.contains(&byte)) {
101            return info;
102        }
103        base
104    };
105
106    let mut line = Line::default();
107    let mut run_start = 0;
108    let mut current = color_at(0);
109    for (i, _) in input.char_indices().skip(1) {
110        let c = color_at(i);
111        if c != current {
112            line.push_styled(&input[run_start..i], current);
113            run_start = i;
114            current = c;
115        }
116    }
117    line.push_styled(&input[run_start..], current);
118    line
119}
120
/// Byte ranges of `@mentions` in `input`, in ascending order and non-overlapping.
///
/// A mention starts at an `@` and runs up to, but not including, the next whitespace
/// character, or to the end of the input. An `@` that falls inside an earlier mention (as in
/// `@a@b`) is treated as part of that mention rather than starting a new one.
fn mention_ranges(input: &str) -> Vec<Range<usize>> {
    let mut ranges: Vec<Range<usize>> = Vec::new();
    for (start, _) in input.match_indices('@') {
        if ranges.last().is_some_and(|previous| start < previous.end) {
            continue;
        }
        let end = input[start..]
            .find(char::is_whitespace)
            .map_or(input.len(), |offset| start + offset);
        ranges.push(start..end);
    }
    ranges
}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `(text, foreground)` pairs for every span on `line`.
    fn span_colors(line: &Line) -> Vec<(String, Option<tui::Color>)> {
        line.spans().iter().map(|span| (span.text().to_owned(), span.style().fg)).collect()
    }

    #[test]
    fn renders_three_lines() {
        let prompt = InputPrompt { input: "", cursor_index: 0, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert_eq!(lines.len(), 3);
    }

    #[test]
    fn top_rule_is_horizontal_line() {
        let prompt = InputPrompt { input: "", cursor_index: 0, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert!(lines[0].plain_text().chars().all(|c| c == '─'));
        assert_eq!(lines[0].display_width(), 80);
    }

    #[test]
    fn bottom_rule_is_horizontal_line() {
        let prompt = InputPrompt { input: "", cursor_index: 0, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert!(lines[2].plain_text().chars().all(|c| c == '─'));
    }

    #[test]
    fn middle_line_contains_prompt() {
        let prompt = InputPrompt { input: "", cursor_index: 0, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert!(lines[1].plain_text().starts_with("> "));
    }

    #[test]
    fn renders_input_text() {
        let prompt = InputPrompt { input: "hello", cursor_index: 5, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert!(lines[1].plain_text().contains("hello"));
    }

    #[test]
    fn renders_consistently() {
        let prompt = InputPrompt { input: "test", cursor_index: 4, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let a = prompt.render(&ctx);
        let b = prompt.render(&ctx);
        assert_eq!(a, b);
    }

    #[test]
    fn adapts_to_terminal_width() {
        let prompt = InputPrompt { input: "", cursor_index: 0, highlight_range: None };
        let narrow = ViewContext::new((40, 24));
        let wide = ViewContext::new((120, 24));
        let narrow_lines = prompt.render(&narrow);
        let wide_lines = prompt.render(&wide);
        // Both should produce 3 lines but different widths
        assert_eq!(narrow_lines.len(), 3);
        assert_eq!(wide_lines.len(), 3);
        // Wide border should be longer than narrow
        assert!(wide_lines[0].plain_text().len() > narrow_lines[0].plain_text().len());
    }

    #[test]
    fn wraps_long_input() {
        let prompt = InputPrompt {
            input: "this is a very long input that should wrap",
            cursor_index: 41,
            highlight_range: None,
        };
        let ctx = ViewContext::new((20, 24));
        let lines = prompt.render(&ctx);
        assert!(lines.len() > 3);
    }

    #[test]
    fn hard_newline_renders_continuation_row() {
        let prompt = InputPrompt { input: "one\ntwo", cursor_index: "one\ntwo".len(), highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);
        let line_text = layout.lines.iter().map(Line::plain_text).collect::<Vec<_>>();
        assert_eq!(line_text, vec!["─".repeat(80), "> one".to_owned(), "  two".to_owned(), "─".repeat(80)]);
        assert_eq!((layout.cursor_row, layout.cursor_col), (2, 5));
    }

    #[test]
    fn cursor_after_hard_newline_starts_continuation_row() {
        let prompt = InputPrompt { input: "one\ntwo", cursor_index: 4, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);
        assert_eq!((layout.cursor_row, layout.cursor_col), (2, 2));
    }

    #[test]
    fn cursor_after_wrapped_space_uses_rendered_row() {
        let prompt = InputPrompt { input: "hello world", cursor_index: "hello ".len(), highlight_range: None };
        let ctx = ViewContext::new((9, 24));
        let layout = prompt.layout(&ctx);
        let line_text = layout.lines.iter().map(Line::plain_text).collect::<Vec<_>>();

        assert_eq!(line_text, vec!["─".repeat(9), "> hello".to_owned(), "  world".to_owned(), "─".repeat(9)]);
        assert_eq!((layout.cursor_row, layout.cursor_col), (2, 2));
    }

    #[test]
    fn cursor_inside_multibyte_char_clamps_to_char_start() {
        // Byte 1 is inside the two-byte 'é'; the cursor must snap back to byte 0.
        let prompt = InputPrompt { input: "é", cursor_index: 1, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);
        assert_eq!((layout.cursor_row, layout.cursor_col), (1, 2));
    }

    #[test]
    fn cursor_index_past_end_clamps_to_input_len() {
        let prompt = InputPrompt { input: "abc", cursor_index: 100, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);
        assert_eq!((layout.cursor_row, layout.cursor_col), (1, 5));
    }

    #[test]
    fn mention_and_plain_text_both_render() {
        let prompt = InputPrompt { input: "@main.rs explain this", cursor_index: 20, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let lines = prompt.render(&ctx);
        assert!(lines[1].plain_text().contains("@main.rs"));
        assert!(lines[1].plain_text().contains("explain this"));
    }

    #[test]
    fn hard_newline_terminates_mention_styling() {
        let prompt = InputPrompt { input: "@main.rs\nhello", cursor_index: 14, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);

        let styled_spans = layout.lines.iter().map(span_colors).collect::<Vec<_>>();

        assert_eq!(
            styled_spans,
            vec![
                vec![("─".repeat(80), Some(ctx.theme.muted()))],
                vec![("> ".to_owned(), Some(ctx.theme.primary())), ("@main.rs".to_owned(), Some(ctx.theme.info()))],
                vec![("  ".to_owned(), Some(ctx.theme.muted())), ("hello".to_owned(), Some(ctx.theme.text_primary()))],
                vec![("─".repeat(80), Some(ctx.theme.muted()))],
            ]
        );
    }

    #[test]
    fn each_mention_on_its_own_row_is_styled() {
        let prompt = InputPrompt { input: "@a\n@b\nc", cursor_index: 7, highlight_range: None };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);

        assert_eq!(
            span_colors(&layout.lines[1]),
            vec![("> ".to_owned(), Some(ctx.theme.primary())), ("@a".to_owned(), Some(ctx.theme.info()))]
        );
        assert_eq!(
            span_colors(&layout.lines[2]),
            vec![("  ".to_owned(), Some(ctx.theme.muted())), ("@b".to_owned(), Some(ctx.theme.info()))]
        );
        assert_eq!(
            span_colors(&layout.lines[3]),
            vec![("  ".to_owned(), Some(ctx.theme.muted())), ("c".to_owned(), Some(ctx.theme.text_primary()))]
        );
    }

    #[test]
    fn highlight_takes_precedence_over_mention() {
        let prompt = InputPrompt { input: "@ab", cursor_index: 3, highlight_range: Some(1..2) };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);

        assert_eq!(
            span_colors(&layout.lines[1]),
            vec![
                ("> ".to_owned(), Some(ctx.theme.primary())),
                ("@".to_owned(), Some(ctx.theme.info())),
                ("a".to_owned(), Some(ctx.theme.warning())),
                ("b".to_owned(), Some(ctx.theme.info())),
            ]
        );
    }

    #[test]
    fn out_of_bounds_highlight_is_ignored() {
        let prompt = InputPrompt { input: "hi", cursor_index: 2, highlight_range: Some(0..10) };
        let ctx = ViewContext::new((80, 24));
        let layout = prompt.layout(&ctx);

        assert_eq!(
            span_colors(&layout.lines[1]),
            vec![("> ".to_owned(), Some(ctx.theme.primary())), ("hi".to_owned(), Some(ctx.theme.text_primary()))]
        );
    }

    #[test]
    fn mention_ranges_stop_at_whitespace_and_absorb_embedded_at() {
        assert!(mention_ranges("no mentions here").is_empty());
        assert_eq!(mention_ranges("@a@b c @d"), vec![0..4, 7..9]);
        assert_eq!(mention_ranges("x @y\nz"), vec![2..4]);
    }
}