use async_trait::async_trait;
use std::io::Write as _;
use super::ax_converter::AxConverter;
use super::pipeline::{extract_ax_text_blocks, PerceptionError, PerceptionPipeline};
use super::signals::SignalDetector;
use super::ui_map::{ElementSource, TextBlock, UiElement, UiMap};
use crate::models::{A11yNode, Bounds, Viewport};
/// One piece of text recognized by OCR on the screenshot.
#[derive(Debug, Clone)]
struct OcrRegion {
    /// Recognized text (already trimmed when produced by `run_ocr_blocking`).
    text: String,
    /// Box of the text, scaled by the width/height handed to `run_ocr_blocking`.
    bounds: Bounds,
    /// Recognition confidence as reported by the OCR engine.
    confidence: f32,
}
/// Perception pipeline that starts from the accessibility (AX) tree and, when
/// a screenshot and the vision backend are available, augments it with OCR.
pub struct VisionPerceptionPipeline {
    /// Turns raw AX nodes into UI elements.
    converter: AxConverter,
    /// Derives page-level signals from the AX nodes.
    signal_detector: SignalDetector,
}
impl VisionPerceptionPipeline {
    /// Creates a pipeline with a freshly constructed converter and signal detector.
    pub fn new() -> Self {
        let converter = AxConverter::new();
        let signal_detector = SignalDetector::new();
        Self {
            converter,
            signal_detector,
        }
    }
}
impl Default for VisionPerceptionPipeline {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PerceptionPipeline for VisionPerceptionPipeline {
    /// Builds a [`UiMap`] from the accessibility tree, augmented with OCR when possible.
    ///
    /// Elements, text blocks and page signals are always derived from
    /// `a11y_nodes`. If `screenshot` is non-empty and `car_vision` reports it is
    /// available, OCR runs on the screenshot via `spawn_blocking` and its
    /// results are merged in with `merge_ocr`.
    ///
    /// OCR is best-effort: recognition errors and task-join failures are logged
    /// at debug level and the AX-only result is returned. As written, this
    /// implementation never returns `Err`.
    async fn perceive(
        &self,
        screenshot: &[u8],
        a11y_nodes: &[A11yNode],
        url: &str,
        viewport: Viewport,
    ) -> Result<UiMap, PerceptionError> {
        let mut elements = self.converter.convert(a11y_nodes);
        let mut text_blocks = extract_ax_text_blocks(a11y_nodes);
        let page_signals = self.signal_detector.detect(a11y_nodes);

        if !screenshot.is_empty() && car_vision::is_available() {
            // The blocking task needs owned data, so copy the screenshot bytes.
            let bytes = screenshot.to_vec();
            // Scale OCR boxes into the coordinate space the AX bounds use.
            let (cw, ch) = css_viewport_from_ax(a11y_nodes, viewport);
            let ocr = tokio::task::spawn_blocking(move || run_ocr_blocking(&bytes, cw, ch)).await;
            match ocr {
                Ok(Ok(regions)) if !regions.is_empty() => {
                    merge_ocr(&mut elements, &mut text_blocks, &regions);
                }
                // OCR ran but recognized nothing: nothing to merge.
                Ok(Ok(_)) => {}
                Ok(Err(e)) => tracing::debug!(error = %e, "OCR augmentation skipped"),
                Err(e) => tracing::debug!(error = %e, "OCR task join failed"),
            }
        }

        Ok(UiMap::new(
            url.to_string(),
            elements,
            text_blocks,
            page_signals,
            viewport,
            // NOTE(review): last argument is passed empty; its meaning is defined by `UiMap::new`.
            String::new(),
        ))
    }
}
/// Picks the width/height used to scale normalized OCR boxes.
///
/// Resolution order:
/// 1. the bounds of the first node whose lowercased role contains `"webarea"`
///    (this also matches `RootWebArea`) and whose width and height are positive;
/// 2. otherwise the largest right/bottom edge over all nodes, if both are positive;
/// 3. otherwise the `viewport` width and height.
fn css_viewport_from_ax(nodes: &[A11yNode], viewport: Viewport) -> (f64, f64) {
    let web_area = nodes.iter().find(|node| {
        node.role.to_lowercase().contains("webarea")
            && node.bounds.width > 0.0
            && node.bounds.height > 0.0
    });
    if let Some(node) = web_area {
        return (node.bounds.width, node.bounds.height);
    }

    // No usable root node: approximate the page size by the farthest right and
    // bottom edges reported by any node.
    let (extent_x, extent_y) = nodes
        .iter()
        .fold((0.0_f64, 0.0_f64), |(far_x, far_y), node| {
            let b = &node.bounds;
            (far_x.max(b.x + b.width), far_y.max(b.y + b.height))
        });

    if extent_x > 0.0 && extent_y > 0.0 {
        return (extent_x, extent_y);
    }
    (viewport.width as f64, viewport.height as f64)
}
/// Runs OCR on `screenshot` synchronously and returns the recognized regions.
///
/// The bytes are written to a temporary `.png` file because the recognizer is
/// invoked with a path. Observations with blank text or a non-positive
/// width/height are dropped. Remaining boxes are scaled by `cw` x `ch`, and the
/// vertical axis is flipped (`1 - y - h`).
///
/// NOTE(review): the flip presumably converts a bottom-left-origin normalized
/// box to top-left origin — confirm against `car_vision`'s coordinate convention.
///
/// # Errors
/// Returns a human-readable message if the temp file cannot be created,
/// written or flushed, or if recognition fails.
fn run_ocr_blocking(screenshot: &[u8], cw: f64, ch: f64) -> Result<Vec<OcrRegion>, String> {
    // `tmp` must stay alive until recognition finishes; it is dropped on return.
    let mut tmp = tempfile::Builder::new()
        .suffix(".png")
        .tempfile()
        .map_err(|e| format!("temp file: {e}"))?;
    tmp.write_all(screenshot)
        .map_err(|e| format!("write screenshot: {e}"))?;
    tmp.flush().map_err(|e| format!("flush: {e}"))?;

    let config = car_vision::ocr::OcrConfig {
        fast_path: true,
        languages: Vec::new(),
        language_correction: true,
        minimum_text_height: 0.0,
    };
    let observations =
        car_vision::ocr::recognize(tmp.path(), &config).map_err(|e| format!("ocr: {e}"))?;

    let regions = observations
        .into_iter()
        .filter_map(|obs| {
            let text = obs.text.trim();
            // Positive comparisons on purpose: a NaN size is rejected as well.
            let usable = !text.is_empty() && obs.w > 0.0 && obs.h > 0.0;
            usable.then(|| OcrRegion {
                text: text.to_string(),
                bounds: Bounds::new(
                    obs.x * cw,
                    (1.0 - obs.y - obs.h) * ch,
                    obs.w * cw,
                    obs.h * ch,
                ),
                confidence: obs.confidence,
            })
        })
        .collect();
    Ok(regions)
}
/// Merges OCR `regions` into the AX-derived `elements` and `text_blocks`.
///
/// For each region, in order:
/// * **Label recovery** — if at least 60% of the region's area lies inside an
///   interactable element that has no (non-blank) name, the smallest such
///   element receives the region text as its name, is marked as
///   `ElementSource::Merged` (AX + OCR) and gets the merged base confidence.
/// * **New text** — otherwise the region becomes an OCR text block, unless the
///   same text (compared trimmed and case-insensitively) is already known from
///   an element name, an existing text block, or an earlier region.
///
/// Regions whose text is blank are ignored, so they can neither label an
/// element nor create an empty text block.
fn merge_ocr(elements: &mut [UiElement], text_blocks: &mut Vec<TextBlock>, regions: &[OcrRegion]) {
    /// Minimum fraction of a region's area that must fall inside an element.
    const MIN_CONTAINMENT: f64 = 0.60;

    /// Normalization applied to every text before duplicate comparison.
    fn dedup_key(text: &str) -> String {
        text.trim().to_lowercase()
    }

    // Everything the AX tree already exposes as text.
    let mut known: std::collections::HashSet<String> = elements
        .iter()
        .filter_map(|el| el.name.as_deref())
        .map(dedup_key)
        .chain(text_blocks.iter().map(|tb| dedup_key(&tb.text)))
        .collect();

    for region in regions {
        let key = dedup_key(&region.text);
        if key.is_empty() {
            continue;
        }

        // Smallest nameless interactable element that mostly contains the region.
        let target = elements
            .iter_mut()
            .filter(|el| {
                el.role.is_interactable()
                    && el.is_interactable()
                    && el.name.as_deref().map_or(true, |n| n.trim().is_empty())
                    && containment_ratio(&region.bounds, &el.bounds) >= MIN_CONTAINMENT
            })
            .min_by(|lhs, rhs| {
                let area = |bounds: &Bounds| bounds.width * bounds.height;
                area(&lhs.bounds)
                    .partial_cmp(&area(&rhs.bounds))
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

        if let Some(el) = target {
            el.name = Some(region.text.clone());
            el.source = ElementSource::Merged {
                sources: vec![ElementSource::AccessibilityTree, ElementSource::Ocr],
            };
            el.confidence = ElementSource::Merged { sources: Vec::new() }.base_confidence();
            // The text now lives on the element; later duplicates must not become blocks.
            known.insert(key);
            continue;
        }

        // `insert` returns true only when the text was not known before.
        if known.insert(key) {
            text_blocks.push(TextBlock::from_ocr(
                region.text.clone(),
                region.bounds,
                region.confidence,
            ));
        }
    }
}
/// Fraction of `region`'s area that overlaps `el`.
///
/// Returns `0.0` when `region` has no positive area or the boxes do not
/// overlap. The value is relative to the region only, so a small region fully
/// inside a large element yields `1.0`.
fn containment_ratio(region: &Bounds, el: &Bounds) -> f64 {
    let region_area = region.width * region.height;
    if region_area <= 0.0 {
        return 0.0;
    }

    // Edges of the intersection rectangle.
    let left = region.x.max(el.x);
    let top = region.y.max(el.y);
    let right = (region.x + region.width).min(el.x + el.width);
    let bottom = (region.y + region.height).min(el.y + el.height);

    // A negative span means the boxes are disjoint on that axis.
    let overlap_w = (right - left).max(0.0);
    let overlap_h = (bottom - top).max(0.0);

    overlap_w * overlap_h / region_area
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::perception::ui_map::{TextSource, UiRole, UiState};

    /// Builds an enabled, AX-sourced button without a name, positioned at `b`.
    fn nameless_button(id: &str, b: Bounds) -> UiElement {
        UiElement {
            id: id.to_string(),
            role: UiRole::Button,
            name: None,
            value: None,
            bounds: b,
            states: UiState::enabled(),
            confidence: 0.9,
            source: ElementSource::AccessibilityTree,
            icon_type: None,
            children: vec![],
            ax_ref: Some(format!("ax-{id}")),
        }
    }

    /// OCR text inside a nameless button becomes its name; the AX ref is kept.
    #[test]
    fn label_recovery_fills_nameless_element_and_marks_merged() {
        let mut els = vec![nameless_button("el_0", Bounds::new(100.0, 100.0, 80.0, 30.0))];
        let mut tbs: Vec<TextBlock> = vec![];
        let regions = vec![OcrRegion {
            text: "Submit".to_string(),
            bounds: Bounds::new(110.0, 105.0, 60.0, 20.0),
            confidence: 0.95,
        }];
        merge_ocr(&mut els, &mut tbs, &regions);
        assert_eq!(els[0].name.as_deref(), Some("Submit"));
        assert!(matches!(els[0].source, ElementSource::Merged { .. }));
        assert_eq!(els[0].ax_ref.as_deref(), Some("ax-el_0"));
        assert!(tbs.is_empty());
    }

    /// Text with no AX counterpart is surfaced as an OCR text block.
    #[test]
    fn invisible_text_becomes_ocr_text_block() {
        let mut els: Vec<UiElement> = vec![];
        let mut tbs: Vec<TextBlock> = vec![];
        let regions = vec![OcrRegion {
            text: "Score: 42".to_string(),
            bounds: Bounds::new(500.0, 20.0, 90.0, 18.0),
            confidence: 0.88,
        }];
        merge_ocr(&mut els, &mut tbs, &regions);
        assert_eq!(tbs.len(), 1);
        assert_eq!(tbs[0].text, "Score: 42");
        assert_eq!(tbs[0].source, TextSource::Ocr);
    }

    /// OCR text equal to an existing AX name is not duplicated as a text block.
    #[test]
    fn ocr_duplicate_of_ax_text_is_dropped() {
        let mut els = vec![UiElement {
            name: Some("Welcome back".to_string()),
            ..nameless_button("el_0", Bounds::new(0.0, 0.0, 300.0, 40.0))
        }];
        let mut tbs: Vec<TextBlock> = vec![];
        let regions = vec![OcrRegion {
            text: "Welcome back".to_string(),
            bounds: Bounds::new(800.0, 800.0, 100.0, 20.0),
            confidence: 0.9,
        }];
        merge_ocr(&mut els, &mut tbs, &regions);
        assert!(tbs.is_empty(), "duplicate of an AX name must be dropped");
    }

    /// A sized RootWebArea node wins over the viewport passed by the caller.
    #[test]
    fn css_viewport_prefers_rootwebarea_over_passed_viewport() {
        use crate::models::A11yNode;
        let nodes = vec![A11yNode {
            node_id: "root".into(),
            role: "RootWebArea".into(),
            name: None,
            value: None,
            bounds: Bounds::new(0.0, 0.0, 800.0, 600.0),
            children: vec![],
            focusable: false,
            focused: false,
            disabled: false,
        }];
        let vp = Viewport {
            width: 1280,
            height: 720,
            device_pixel_ratio: 1.0,
        };
        assert_eq!(css_viewport_from_ax(&nodes, vp), (800.0, 600.0));
    }

    /// Without a web-area node the max node extent is used, then the viewport.
    #[test]
    fn css_viewport_falls_back_to_max_extent_then_viewport() {
        use crate::models::A11yNode;
        let nodes = vec![A11yNode {
            node_id: "b".into(),
            role: "button".into(),
            name: Some("x".into()),
            value: None,
            bounds: Bounds::new(10.0, 20.0, 100.0, 30.0),
            children: vec![],
            focusable: true,
            focused: false,
            disabled: false,
        }];
        let vp = Viewport {
            width: 1280,
            height: 720,
            device_pixel_ratio: 1.0,
        };
        assert_eq!(css_viewport_from_ax(&nodes, vp), (110.0, 50.0));
        assert_eq!(css_viewport_from_ax(&[], vp), (1280.0, 720.0));
    }

    /// An element that already has an AX name is never overwritten by OCR.
    #[test]
    fn named_element_is_not_relabeled() {
        let mut els = vec![UiElement {
            name: Some("Sign in".to_string()),
            ..nameless_button("el_0", Bounds::new(100.0, 100.0, 80.0, 30.0))
        }];
        let mut tbs: Vec<TextBlock> = vec![];
        let regions = vec![OcrRegion {
            text: "garbled ocr".to_string(),
            bounds: Bounds::new(120.0, 108.0, 40.0, 14.0),
            confidence: 0.4,
        }];
        merge_ocr(&mut els, &mut tbs, &regions);
        assert_eq!(els[0].name.as_deref(), Some("Sign in"));
        assert!(matches!(els[0].source, ElementSource::AccessibilityTree));
    }
}