Skip to main content

appscale_core/
ai.rs

1//! AI Layer — IR generation, layout optimization, and training data.
2//!
3//! This module enables three AI-driven capabilities:
4//! 1. **IR Generation**: AI models produce Binary IR directly, bypassing React
5//!    reconciler when appropriate (deterministic IR makes this possible)
6//! 2. **Layout Optimization**: Analyze layout trees and suggest performance
7//!    improvements (e.g., flatten unnecessary nesting, merge redundant nodes)
8//! 3. **IR Replay for Training**: Export recorded IR sessions as structured
9//!    training data for fine-tuning layout/UI generation models
10//!
11//! The AI layer reads from `ir.rs` types, `devtools::IrRecorder`, and
12//! `layout::LayoutEngine` — it does NOT modify the core engine loop.
13
14use crate::ir::{IrBatch, IrCommand};
15use crate::layout::LayoutEngine;
16use crate::tree::{NodeId, ShadowTree};
17use crate::platform::ViewType;
18use serde::{Serialize, Deserialize};
19use std::collections::HashMap;
20
21// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
22// IR Generation from AI
23// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
24
/// A prompt-to-IR generation request.
///
/// AI models receive this context and produce an `IrBatch`; see
/// [`IrGenerationResult`] for the shape of the generation output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IrGenerationRequest {
    /// Natural language description of the desired UI (e.g. "Create a login form").
    pub prompt: String,

    /// Target platform name, e.g. `"ios"` (affects layout defaults and capabilities).
    pub platform: String,

    /// Screen width used as the layout constraint.
    pub screen_width: f32,
    /// Screen height used as the layout constraint.
    pub screen_height: f32,

    /// Optional: number of nodes already present in an existing tree, for
    /// incremental updates. This is a count only, not a tree snapshot.
    pub existing_node_count: Option<u32>,

    /// Optional: component palette to constrain generation
    /// (e.g. `"View"`, `"TextInput"`, `"Button"`).
    pub allowed_components: Option<Vec<String>>,
}
45
/// Result of AI IR generation.
///
/// Bundles the generated batch with metadata reported by the generator.
/// `validate_generated_batch` can be used to check `batch` before applying it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IrGenerationResult {
    /// The generated IR batch ready to apply.
    pub batch: IrBatch,

    /// Confidence score (0.0 – 1.0) from the model. The range is not enforced
    /// by this type.
    pub confidence: f32,

    /// Warnings or notes from the generation process.
    pub warnings: Vec<String>,

    /// Token/node count stats.
    pub stats: GenerationStats,
}
61
/// Summary counters for one generation run, carried in `IrGenerationResult::stats`.
///
/// All values are populated by the generator; nothing in this module computes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationStats {
    /// Nodes created by the generated batch, as reported by the generator.
    pub nodes_created: u32,
    /// IR commands in the generated batch, as reported by the generator.
    pub commands_generated: u32,
    /// Generation time in milliseconds, as reported by the generator.
    pub generation_time_ms: f64,
}
68
69/// Validates that a generated IR batch is well-formed.
70/// Checks: node IDs are unique, parent references are valid, required props present.
71pub fn validate_generated_batch(batch: &IrBatch) -> Vec<ValidationIssue> {
72    let mut issues = Vec::new();
73    let mut created_ids: HashMap<u64, usize> = HashMap::new();
74    let mut parented_ids: Vec<u64> = Vec::new();
75
76    for (idx, cmd) in batch.commands.iter().enumerate() {
77        match cmd {
78            IrCommand::CreateNode { id, view_type, .. } => {
79                if let Some(&prev_idx) = created_ids.get(&id.0) {
80                    issues.push(ValidationIssue {
81                        command_index: idx,
82                        severity: IssueSeverity::Error,
83                        message: format!(
84                            "Duplicate node ID {} (first created at command {})",
85                            id.0, prev_idx
86                        ),
87                    });
88                }
89                created_ids.insert(id.0, idx);
90
91                // Text nodes should have a "text" or "content" prop
92                if matches!(view_type, ViewType::Text) {
93                    // Warn only — not blocking
94                    if let IrCommand::CreateNode { props, .. } = cmd {
95                        if !props.contains_key("text") && !props.contains_key("content") {
96                            issues.push(ValidationIssue {
97                                command_index: idx,
98                                severity: IssueSeverity::Warning,
99                                message: format!("Text node {} has no 'text' or 'content' prop", id.0),
100                            });
101                        }
102                    }
103                }
104            }
105            IrCommand::AppendChild { parent, child } => {
106                if !created_ids.contains_key(&parent.0) {
107                    issues.push(ValidationIssue {
108                        command_index: idx,
109                        severity: IssueSeverity::Error,
110                        message: format!("AppendChild references unknown parent {}", parent.0),
111                    });
112                }
113                if !created_ids.contains_key(&child.0) {
114                    issues.push(ValidationIssue {
115                        command_index: idx,
116                        severity: IssueSeverity::Error,
117                        message: format!("AppendChild references unknown child {}", child.0),
118                    });
119                }
120                parented_ids.push(child.0);
121            }
122            IrCommand::SetRootNode { id } => {
123                if !created_ids.contains_key(&id.0) {
124                    issues.push(ValidationIssue {
125                        command_index: idx,
126                        severity: IssueSeverity::Error,
127                        message: format!("SetRootNode references unknown node {}", id.0),
128                    });
129                }
130            }
131            _ => {}
132        }
133    }
134
135    // Warn about orphan nodes (created but never parented and not root)
136    let has_root = batch.commands.iter().any(|c| matches!(c, IrCommand::SetRootNode { .. }));
137    if has_root {
138        for (&id, _) in &created_ids {
139            let is_root = batch.commands.iter().any(|c| matches!(c, IrCommand::SetRootNode { id: r } if r.0 == id));
140            if !is_root && !parented_ids.contains(&id) {
141                issues.push(ValidationIssue {
142                    command_index: 0,
143                    severity: IssueSeverity::Warning,
144                    message: format!("Node {} is created but never attached to the tree", id),
145                });
146            }
147        }
148    }
149
150    issues
151}
152
/// A validation issue found in a generated IR batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationIssue {
    /// Index into the validated batch's `commands` that the issue refers to.
    pub command_index: usize,
    /// Whether the issue is blocking (`Error`) or advisory (`Warning`).
    pub severity: IssueSeverity,
    /// Human-readable description of the problem.
    pub message: String,
}
160
/// Severity of a [`ValidationIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IssueSeverity {
    /// Advisory, non-blocking finding (e.g. a Text node without content, an orphan node).
    Warning,
    /// The batch is malformed (e.g. duplicate node IDs, references to unknown nodes).
    Error,
}
166
167// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
168// Layout Optimization
169// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
170
/// A layout optimization hint for a specific node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayoutHint {
    /// Raw ID (`NodeId.0`) of the node the hint refers to.
    pub node_id: u64,
    /// Category of the finding, with any category-specific data.
    pub hint_type: LayoutHintType,
    /// Human-readable explanation and suggested fix.
    pub description: String,
    /// Estimated performance impact (0.0 – 1.0, higher = more impactful).
    /// `analyze_layout` sorts its output by this value, highest first.
    pub impact: f32,
}
180
/// Kinds of layout findings reported in a [`LayoutHint`].
///
/// The analysis passes in this module currently emit only `UnnecessaryWrapper`,
/// `DeepNesting`, and `UnoptimizedList`. `OverconstrainedLayout` and
/// `DuplicateStyles` are defined but not yet produced by `analyze_layout`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayoutHintType {
    /// Node wraps a single child with no visual changes — can be flattened.
    UnnecessaryWrapper,
    /// Deep nesting that could be restructured with flexbox.
    /// `depth` is measured from the root (root = 0).
    DeepNesting { depth: u32 },
    /// Node has fixed dimensions but uses flex — simplify.
    OverconstrainedLayout,
    /// Sibling nodes have identical styles — consider shared component.
    DuplicateStyles { sibling_count: u32 },
    /// Large flat list without recycling hints.
    /// `child_count` is the number of direct children.
    UnoptimizedList { child_count: u32 },
}
194
195/// Analyze a shadow tree and produce layout optimization hints.
196pub fn analyze_layout(tree: &ShadowTree, layout: &LayoutEngine) -> Vec<LayoutHint> {
197    let mut hints = Vec::new();
198
199    let root = match tree.root() {
200        Some(r) => r,
201        None => return hints,
202    };
203
204    // Analysis passes
205    detect_unnecessary_wrappers(root, tree, &mut hints);
206    detect_deep_nesting(root, tree, 0, &mut hints);
207    detect_large_flat_lists(root, tree, &mut hints);
208
209    // Sort by impact (highest first)
210    hints.sort_by(|a, b| b.impact.partial_cmp(&a.impact).unwrap_or(std::cmp::Ordering::Equal));
211    hints
212}
213
214fn detect_unnecessary_wrappers(
215    node_id: NodeId,
216    tree: &ShadowTree,
217    hints: &mut Vec<LayoutHint>,
218) {
219    let node = match tree.get(node_id) {
220        Some(n) => n,
221        None => return,
222    };
223
224    // A wrapper is unnecessary if:
225    // - It's a View with exactly 1 child
226    // - It has no event handlers (checked via empty props as proxy)
227    // - It's not the root
228    if matches!(node.view_type, ViewType::Container)
229        && node.children.len() == 1
230        && node.props.is_empty()
231        && tree.root() != Some(node_id)
232    {
233        hints.push(LayoutHint {
234            node_id: node_id.0,
235            hint_type: LayoutHintType::UnnecessaryWrapper,
236            description: format!(
237                "View node {} wraps a single child with no props — consider removing",
238                node_id.0
239            ),
240            impact: 0.3,
241        });
242    }
243
244    for &child in &node.children {
245        detect_unnecessary_wrappers(child, tree, hints);
246    }
247}
248
249fn detect_deep_nesting(
250    node_id: NodeId,
251    tree: &ShadowTree,
252    depth: u32,
253    hints: &mut Vec<LayoutHint>,
254) {
255    const DEEP_THRESHOLD: u32 = 10;
256
257    if depth >= DEEP_THRESHOLD {
258        hints.push(LayoutHint {
259            node_id: node_id.0,
260            hint_type: LayoutHintType::DeepNesting { depth },
261            description: format!(
262                "Node {} is nested {} levels deep — consider flattening with flexbox",
263                node_id.0, depth
264            ),
265            impact: 0.6,
266        });
267        return; // Don't report children — the hint covers the subtree
268    }
269
270    let node = match tree.get(node_id) {
271        Some(n) => n,
272        None => return,
273    };
274
275    for &child in &node.children {
276        detect_deep_nesting(child, tree, depth + 1, hints);
277    }
278}
279
280fn detect_large_flat_lists(
281    node_id: NodeId,
282    tree: &ShadowTree,
283    hints: &mut Vec<LayoutHint>,
284) {
285    const LARGE_LIST_THRESHOLD: usize = 50;
286
287    let node = match tree.get(node_id) {
288        Some(n) => n,
289        None => return,
290    };
291
292    // A ScrollView or View with many direct children
293    if (matches!(node.view_type, ViewType::ScrollView) || matches!(node.view_type, ViewType::Container))
294        && node.children.len() > LARGE_LIST_THRESHOLD
295    {
296        hints.push(LayoutHint {
297            node_id: node_id.0,
298            hint_type: LayoutHintType::UnoptimizedList {
299                child_count: node.children.len() as u32,
300            },
301            description: format!(
302                "{:?} node {} has {} children — consider FlatList with recycling",
303                node.view_type, node_id.0, node.children.len()
304            ),
305            impact: 0.8,
306        });
307    }
308
309    for &child in &node.children {
310        detect_large_flat_lists(child, tree, hints);
311    }
312}
313
314// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
315// IR Replay for Training Data
316// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
317
/// A training data record: an IR session annotated with metadata.
///
/// Built by `export_training_record`; summarized by `compute_training_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingRecord {
    /// Unique session ID.
    pub session_id: String,

    /// Platform this session was recorded on (e.g. `"ios"`, `"web"`).
    pub platform: String,

    /// Screen width at recording time.
    pub screen_width: f32,
    /// Screen height at recording time.
    pub screen_height: f32,

    /// Sequence of IR batches (from `IrRecorder`), in recorded order.
    pub batches: Vec<TrainingBatch>,

    /// Optional: final tree structure for supervised learning.
    pub final_tree: Option<TreeSnapshot>,

    /// Annotation tags (e.g., "login_flow", "settings_page").
    pub tags: Vec<String>,
}
340
/// An IR batch within a training record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingBatch {
    /// Relative timestamp within the session (ms), copied from the recorder.
    pub offset_ms: f64,

    /// The IR batch.
    pub batch: IrBatch,

    /// Optional annotation for this specific batch.
    /// `export_training_record` always leaves this as `None`.
    pub annotation: Option<String>,
}
353
/// Simplified tree snapshot for training data export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSnapshot {
    /// Nodes reachable from the root, in depth-first pre-order (root first).
    pub nodes: Vec<TreeNodeSnapshot>,
    /// Raw ID of the root node, or `None` if the tree had no root
    /// (in which case `nodes` is empty).
    pub root_id: Option<u64>,
}
360
/// A single node in a training data tree snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeNodeSnapshot {
    /// Raw node ID (`NodeId.0`).
    pub id: u64,
    /// `Debug` rendering of the node's `ViewType` (e.g. `"Container"`).
    pub view_type: String,
    /// Raw IDs of the direct children, in tree order.
    pub children: Vec<u64>,
    /// Number of props set on the node.
    pub prop_count: u32,
    /// Distance from the root (root = 0).
    pub depth: u32,
}
370
371/// Export an `IrRecorder`'s recorded batches as a training record.
372pub fn export_training_record(
373    session_id: &str,
374    platform: &str,
375    screen_width: f32,
376    screen_height: f32,
377    recorded_batches: &[crate::devtools::TimestampedBatch],
378    tree: Option<&ShadowTree>,
379    tags: Vec<String>,
380) -> TrainingRecord {
381    let batches = recorded_batches.iter().map(|tb| {
382        TrainingBatch {
383            offset_ms: tb.offset_ms,
384            batch: tb.batch.clone(),
385            annotation: None,
386        }
387    }).collect();
388
389    let final_tree = tree.map(|t| snapshot_tree_for_training(t));
390
391    TrainingRecord {
392        session_id: session_id.to_string(),
393        platform: platform.to_string(),
394        screen_width,
395        screen_height,
396        batches,
397        final_tree,
398        tags,
399    }
400}
401
402fn snapshot_tree_for_training(tree: &ShadowTree) -> TreeSnapshot {
403    let mut nodes = Vec::new();
404    let root_id = tree.root().map(|r| r.0);
405
406    if let Some(root) = tree.root() {
407        snapshot_node_recursive(root, tree, 0, &mut nodes);
408    }
409
410    TreeSnapshot { nodes, root_id }
411}
412
413fn snapshot_node_recursive(
414    node_id: NodeId,
415    tree: &ShadowTree,
416    depth: u32,
417    nodes: &mut Vec<TreeNodeSnapshot>,
418) {
419    let node = match tree.get(node_id) {
420        Some(n) => n,
421        None => return,
422    };
423
424    let children_ids: Vec<u64> = node.children.iter().map(|c| c.0).collect();
425
426    nodes.push(TreeNodeSnapshot {
427        id: node_id.0,
428        view_type: format!("{:?}", node.view_type),
429        children: children_ids.clone(),
430        prop_count: node.props.len() as u32,
431        depth,
432    });
433
434    for &child in &node.children {
435        snapshot_node_recursive(child, tree, depth + 1, nodes);
436    }
437}
438
/// Aggregate statistics about a training record for model analysis.
///
/// Produced by `compute_training_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of batches in the record.
    pub total_batches: usize,
    /// Sum of command counts across all batches.
    pub total_commands: usize,
    /// Command variant name (e.g. `"CreateNode"`) -> occurrence count.
    pub command_histogram: HashMap<String, usize>,
    /// Largest batch `offset_ms` in the record (0.0 when there are no batches).
    pub session_duration_ms: f64,
    /// `total_commands / total_batches`, or 0.0 when there are no batches.
    pub avg_batch_size: f64,
    /// Distinct node IDs referenced by CreateNode, UpdateProps, UpdateStyle,
    /// AppendChild (parent and child), and SetRootNode commands.
    pub unique_node_ids: usize,
}
449
450/// Compute statistics from a training record.
451pub fn compute_training_stats(record: &TrainingRecord) -> TrainingStats {
452    let total_batches = record.batches.len();
453    let mut total_commands = 0;
454    let mut histogram: HashMap<String, usize> = HashMap::new();
455    let mut node_ids: std::collections::HashSet<u64> = std::collections::HashSet::new();
456    let mut max_offset = 0.0_f64;
457
458    for tb in &record.batches {
459        total_commands += tb.batch.commands.len();
460        if tb.offset_ms > max_offset {
461            max_offset = tb.offset_ms;
462        }
463
464        for cmd in &tb.batch.commands {
465            let cmd_type = match cmd {
466                IrCommand::CreateNode { id, .. } => { node_ids.insert(id.0); "CreateNode" }
467                IrCommand::UpdateProps { id, .. } => { node_ids.insert(id.0); "UpdateProps" }
468                IrCommand::UpdateStyle { id, .. } => { node_ids.insert(id.0); "UpdateStyle" }
469                IrCommand::AppendChild { parent, child } => {
470                    node_ids.insert(parent.0);
471                    node_ids.insert(child.0);
472                    "AppendChild"
473                }
474                IrCommand::InsertBefore { .. } => "InsertBefore",
475                IrCommand::RemoveChild { .. } => "RemoveChild",
476                IrCommand::SetRootNode { id } => { node_ids.insert(id.0); "SetRootNode" }
477            };
478            *histogram.entry(cmd_type.to_string()).or_insert(0) += 1;
479        }
480    }
481
482    TrainingStats {
483        total_batches,
484        total_commands,
485        command_histogram: histogram,
486        session_duration_ms: max_offset,
487        avg_batch_size: if total_batches > 0 { total_commands as f64 / total_batches as f64 } else { 0.0 },
488        unique_node_ids: node_ids.len(),
489    }
490}
491
492// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
493// Tests
494// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
495
// Unit tests covering batch validation, layout analysis on an empty tree,
// training-record export/statistics, and serde JSON round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::IrBatch;
    use crate::layout::LayoutStyle;
    use crate::platform::PropValue;

    // Builds a minimal well-formed batch (4 commands):
    // Container(1) -> Text(2, text="Hello"), with node 1 set as root.
    fn make_batch() -> IrBatch {
        // NOTE(review): the meaning of the `1` argument is defined in ir.rs — confirm.
        let mut batch = IrBatch::new(1);
        batch.push(IrCommand::CreateNode {
            id: NodeId(1),
            view_type: ViewType::Container,
            props: HashMap::new(),
            style: LayoutStyle::default(),
        });
        batch.push(IrCommand::CreateNode {
            id: NodeId(2),
            view_type: ViewType::Text,
            props: {
                let mut p = HashMap::new();
                p.insert("text".to_string(), PropValue::String("Hello".to_string()));
                p
            },
            style: LayoutStyle::default(),
        });
        batch.push(IrCommand::AppendChild {
            parent: NodeId(1),
            child: NodeId(2),
        });
        batch.push(IrCommand::SetRootNode { id: NodeId(1) });
        batch
    }

    // A well-formed batch produces no Error-severity issues.
    #[test]
    fn validate_well_formed_batch() {
        let batch = make_batch();
        let issues = validate_generated_batch(&batch);
        let errors: Vec<_> = issues.iter().filter(|i| i.severity == IssueSeverity::Error).collect();
        assert!(errors.is_empty(), "Expected no errors: {:?}", errors);
    }

    // Creating the same node ID twice is reported as an Error.
    #[test]
    fn validate_duplicate_ids() {
        let mut batch = IrBatch::new(1);
        batch.push(IrCommand::CreateNode {
            id: NodeId(1),
            view_type: ViewType::Container,
            props: HashMap::new(),
            style: LayoutStyle::default(),
        });
        batch.push(IrCommand::CreateNode {
            id: NodeId(1), // duplicate!
            view_type: ViewType::Text,
            props: HashMap::new(),
            style: LayoutStyle::default(),
        });

        let issues = validate_generated_batch(&batch);
        assert!(issues.iter().any(|i| i.severity == IssueSeverity::Error && i.message.contains("Duplicate")));
    }

    // AppendChild naming a parent that was never created is an Error.
    #[test]
    fn validate_unknown_parent() {
        let mut batch = IrBatch::new(1);
        batch.push(IrCommand::CreateNode {
            id: NodeId(1),
            view_type: ViewType::Container,
            props: HashMap::new(),
            style: LayoutStyle::default(),
        });
        batch.push(IrCommand::AppendChild {
            parent: NodeId(99), // unknown
            child: NodeId(1),
        });

        let issues = validate_generated_batch(&batch);
        assert!(issues.iter().any(|i| i.severity == IssueSeverity::Error && i.message.contains("unknown parent")));
    }

    // A Text node without "text"/"content" props yields a Warning (not an Error).
    #[test]
    fn validate_text_node_warning() {
        let mut batch = IrBatch::new(1);
        batch.push(IrCommand::CreateNode {
            id: NodeId(1),
            view_type: ViewType::Text,
            props: HashMap::new(), // no "text" prop
            style: LayoutStyle::default(),
        });

        let issues = validate_generated_batch(&batch);
        assert!(issues.iter().any(|i| i.severity == IssueSeverity::Warning && i.message.contains("no 'text'")));
    }

    // A tree with no root produces no layout hints.
    #[test]
    fn layout_analysis_empty_tree() {
        let tree = ShadowTree::new();
        let layout = LayoutEngine::new();
        let hints = analyze_layout(&tree, &layout);
        assert!(hints.is_empty());
    }

    // Exporting without a tree copies metadata and batches; final_tree stays None.
    #[test]
    fn training_record_export() {
        let batch = make_batch();
        let timestamped = vec![
            crate::devtools::TimestampedBatch {
                offset_ms: 0.0,
                batch: batch.clone(),
            },
            crate::devtools::TimestampedBatch {
                offset_ms: 16.6,
                batch: batch.clone(),
            },
        ];

        let record = export_training_record(
            "session-001",
            "ios",
            390.0, 844.0,
            &timestamped,
            None,
            vec!["test".to_string()],
        );

        assert_eq!(record.session_id, "session-001");
        assert_eq!(record.platform, "ios");
        assert_eq!(record.batches.len(), 2);
        assert_eq!(record.tags, vec!["test"]);
        assert!(record.final_tree.is_none());
    }

    // Two copies of make_batch(): checks counts, averages, duration, and histogram.
    #[test]
    fn training_stats_computation() {
        let batch = make_batch();
        let record = TrainingRecord {
            session_id: "s1".to_string(),
            platform: "web".to_string(),
            screen_width: 1920.0,
            screen_height: 1080.0,
            batches: vec![
                TrainingBatch { offset_ms: 0.0, batch: batch.clone(), annotation: None },
                TrainingBatch { offset_ms: 16.6, batch: batch.clone(), annotation: None },
            ],
            final_tree: None,
            tags: vec![],
        };

        let stats = compute_training_stats(&record);
        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.total_commands, 8); // 4 cmds × 2 batches
        assert_eq!(stats.unique_node_ids, 2); // nodes 1 and 2
        assert!((stats.avg_batch_size - 4.0).abs() < 0.001);
        assert!((stats.session_duration_ms - 16.6).abs() < 0.001);
        assert_eq!(*stats.command_histogram.get("CreateNode").unwrap(), 4);
        assert_eq!(*stats.command_histogram.get("AppendChild").unwrap(), 2);
        assert_eq!(*stats.command_histogram.get("SetRootNode").unwrap(), 2);
    }

    // IrGenerationRequest survives a serde_json round-trip.
    #[test]
    fn ir_generation_request_serialization() {
        let req = IrGenerationRequest {
            prompt: "Create a login form".to_string(),
            platform: "ios".to_string(),
            screen_width: 390.0,
            screen_height: 844.0,
            existing_node_count: None,
            allowed_components: Some(vec!["View".to_string(), "TextInput".to_string(), "Button".to_string()]),
        };

        let json = serde_json::to_string(&req).unwrap();
        let decoded: IrGenerationRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.prompt, "Create a login form");
        assert_eq!(decoded.allowed_components.unwrap().len(), 3);
    }

    // ValidationIssue survives a serde_json round-trip.
    #[test]
    fn validation_issue_serialization() {
        let issue = ValidationIssue {
            command_index: 5,
            severity: IssueSeverity::Error,
            message: "test error".to_string(),
        };

        let json = serde_json::to_string(&issue).unwrap();
        let decoded: ValidationIssue = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.command_index, 5);
        assert_eq!(decoded.severity, IssueSeverity::Error);
    }
}