use async_trait::async_trait;
use thiserror::Error;
use super::ax_converter::AxConverter;
use super::signals::SignalDetector;
use super::ui_map::{TextBlock, UiMap};
use crate::models::{A11yNode, Viewport};
/// Errors that a [`PerceptionPipeline`] may return from `perceive`.
///
/// Each variant carries a free-form message that is interpolated into the
/// variant's display string.
#[derive(Error, Debug)]
pub enum PerceptionError {
/// A conversion step failed; the payload is the failure description.
#[error("Conversion failed: {0}")]
ConversionFailed(String),
/// A signal-detection step failed; the payload is the failure description.
#[error("Signal detection failed: {0}")]
SignalDetectionFailed(String),
}
/// Asynchronous interface for turning a captured page into a [`UiMap`].
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait PerceptionPipeline: Send + Sync {
/// Builds a [`UiMap`] describing the page identified by `url`.
///
/// * `screenshot` - raw screenshot bytes (image format is not specified by
///   this trait — NOTE(review): confirm with callers).
/// * `a11y_nodes` - accessibility nodes captured for the page.
/// * `url` - URL of the page.
/// * `viewport` - viewport associated with the capture.
///
/// # Errors
///
/// Returns a [`PerceptionError`] when the implementation cannot produce a map.
async fn perceive(
&self,
screenshot: &[u8],
a11y_nodes: &[A11yNode],
url: &str,
viewport: Viewport,
) -> Result<UiMap, PerceptionError>;
}
/// A [`PerceptionPipeline`] that combines an [`AxConverter`] and a
/// [`SignalDetector`], both operating on accessibility nodes.
pub struct BasicPerceptionPipeline {
// Converts accessibility nodes into the map's elements.
converter: AxConverter,
// Derives page-level signals from accessibility nodes.
signal_detector: SignalDetector,
}
impl BasicPerceptionPipeline {
    /// Creates a pipeline backed by a freshly constructed [`AxConverter`] and
    /// [`SignalDetector`].
    pub fn new() -> Self {
        let converter = AxConverter::new();
        let signal_detector = SignalDetector::new();
        Self {
            converter,
            signal_detector,
        }
    }
}
impl Default for BasicPerceptionPipeline {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PerceptionPipeline for BasicPerceptionPipeline {
    /// Builds a [`UiMap`] from the accessibility nodes only.
    ///
    /// The screenshot bytes are accepted to satisfy the trait but are not
    /// read. Elements come from the converter, text blocks from
    /// `extract_text_blocks`, and page signals from the signal detector. The
    /// last argument handed to `UiMap::new` is always an empty string.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    async fn perceive(
        &self,
        _screenshot: &[u8],
        a11y_nodes: &[A11yNode],
        url: &str,
        viewport: Viewport,
    ) -> Result<UiMap, PerceptionError> {
        // Each stage reads the same node slice.
        let ui_elements = self.converter.convert(a11y_nodes);
        let texts = self.extract_text_blocks(a11y_nodes);
        let signals = self.signal_detector.detect(a11y_nodes);

        let ui_map = UiMap::new(
            url.to_string(),
            ui_elements,
            texts,
            signals,
            viewport,
            String::new(),
        );
        Ok(ui_map)
    }
}
impl BasicPerceptionPipeline {
    /// Collects a [`TextBlock`] for every textual accessibility node.
    ///
    /// A node is kept when all of the following hold:
    /// * its role equals `statictext`, `label` or `heading`, ignoring case;
    /// * it has a name (an empty name still counts);
    /// * its bounds have strictly positive width and height.
    ///
    /// Blocks are returned in the same order as the input nodes.
    fn extract_text_blocks(&self, nodes: &[A11yNode]) -> Vec<TextBlock> {
        // All accepted roles are plain ASCII, so an ASCII case-insensitive
        // comparison is enough and avoids allocating a lowercased copy of the
        // role for every node.
        const TEXT_ROLES: [&str; 3] = ["statictext", "label", "heading"];

        nodes
            .iter()
            .filter(|node| {
                TEXT_ROLES
                    .iter()
                    .any(|&role| node.role.eq_ignore_ascii_case(role))
                    && node.bounds.width > 0.0
                    && node.bounds.height > 0.0
            })
            // Unnamed nodes are dropped here; the name is cloned only for
            // nodes that survived the role and size checks.
            .filter_map(|node| {
                node.name
                    .as_ref()
                    .map(|name| TextBlock::from_ax(name.clone(), node.bounds))
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Bounds;

    /// Builds a named, enabled, unfocused node with no value and no children.
    fn leaf(node_id: &str, role: &str, name: &str, bounds: Bounds, focusable: bool) -> A11yNode {
        A11yNode {
            node_id: node_id.to_string(),
            role: role.to_string(),
            name: Some(name.to_string()),
            value: None,
            bounds,
            children: vec![],
            focusable,
            focused: false,
            disabled: false,
        }
    }

    /// One button plus one static-text node should yield two elements and a
    /// single text block carrying the static text's name.
    #[tokio::test]
    async fn test_basic_pipeline() {
        let nodes = vec![
            leaf(
                "n0",
                "button",
                "Submit",
                Bounds::new(100.0, 100.0, 80.0, 30.0),
                true,
            ),
            leaf(
                "n1",
                "statictext",
                "Welcome",
                Bounds::new(0.0, 0.0, 200.0, 20.0),
                false,
            ),
        ];
        let viewport = Viewport {
            width: 1280,
            height: 720,
            device_pixel_ratio: 2.0,
        };

        let pipeline = BasicPerceptionPipeline::new();
        let outcome = pipeline
            .perceive(&[], &nodes, "https://example.com", viewport)
            .await;
        let ui_map = outcome.unwrap();

        assert_eq!(ui_map.elements.len(), 2);
        assert_eq!(ui_map.text_blocks.len(), 1);
        assert_eq!(ui_map.text_blocks[0].text, "Welcome");
    }
}