use crate::traits::BrowserClient;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A serializable action request: an action type plus its JSON parameters.
///
/// `get_index`/`set_index` treat the values of `params` as JSON objects and
/// read/write an `"index"` key inside them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionModel {
/// Identifier of the action (presumably the registered action name — TODO confirm).
pub action_type: String,
/// Raw action parameters; values are expected (but not required) to be JSON objects.
pub params: HashMap<String, serde_json::Value>,
}
impl ActionModel {
    /// Returns the `index` parameter of this action, if any.
    ///
    /// Scans the values of `params` for a JSON object whose `"index"` key holds a
    /// non-negative integer that fits in `u32`. Values that are not objects, objects
    /// without `"index"`, and indices that are negative, non-integral or larger than
    /// `u32::MAX` are skipped. Out-of-range indices are not truncated into a different,
    /// valid-looking index.
    ///
    /// NOTE: `params` is a `HashMap`, so if several objects carry an `"index"` key the
    /// one returned is unspecified.
    pub fn get_index(&self) -> Option<u32> {
        self.params.values().find_map(|value| {
            value
                .as_object()?
                .get("index")?
                .as_u64()
                .and_then(|raw| u32::try_from(raw).ok())
        })
    }

    /// Overwrites an existing `index` parameter with `index`.
    ///
    /// Only the first object value (in `HashMap` iteration order) that already contains
    /// an `"index"` key is updated. If no such object exists this is a no-op; no new key
    /// is inserted.
    pub fn set_index(&mut self, index: u32) {
        let slot = self
            .params
            .values_mut()
            .filter_map(|value| value.as_object_mut())
            .find_map(|obj| obj.get_mut("index"));
        if let Some(slot) = slot {
            // Replace in place: avoids allocating a fresh key and a second map lookup.
            *slot = serde_json::Value::Number(index.into());
        }
    }
}
/// Executes one kind of action against a browser.
///
/// Implementors must be `Send + Sync`. The async method is desugared by `async_trait`.
#[async_trait::async_trait]
pub trait ActionHandler: Send + Sync {
/// Runs the action with the given parameters.
///
/// `context` gives mutable access to the browser client and, optionally, the
/// current selector map. Returns an `ActionResult`, or a crate error on failure.
async fn execute(
&self,
params: &ActionParams,
context: &mut ActionContext<'_>,
) -> crate::error::Result<crate::agent::views::ActionResult>;
}
/// Borrowed state handed to an `ActionHandler` for a single execution.
pub struct ActionContext<'a> {
/// Browser client the action operates on.
pub browser: &'a mut dyn BrowserClient,
/// Optional map from element index to DOM element (presumably the indices shown to
/// the model — TODO confirm); see `ActionParams::backend_node_id_from_index`.
pub selector_map: Option<&'a HashMap<u32, crate::dom::views::DOMInteractedElement>>,
}
/// Read-only, typed accessor view over a borrowed action parameter map.
pub struct ActionParams<'a> {
/// The raw JSON parameters being read.
params: &'a HashMap<String, serde_json::Value>,
/// Optional action type, set via `with_action_type`.
action_type: Option<String>,
}
impl<'a> ActionParams<'a> {
pub fn new(params: &'a HashMap<String, serde_json::Value>) -> Self {
Self {
params,
action_type: None,
}
}
pub fn with_action_type(mut self, action_type: String) -> Self {
self.action_type = Some(action_type);
self
}
pub fn get_action_type(&self) -> Option<&str> {
self.action_type.as_deref()
}
pub fn get_required_u32(&self, key: &str) -> crate::error::Result<u32> {
self.params
.get(key)
.and_then(|v| v.as_u64())
.map(|i| i as u32)
.ok_or_else(|| {
crate::error::BrowsingError::Tool(format!("Missing '{}' parameter", key))
})
}
pub fn get_required_str(&self, key: &str) -> crate::error::Result<&str> {
self.params
.get(key)
.and_then(|v| v.as_str())
.ok_or_else(|| {
crate::error::BrowsingError::Tool(format!("Missing '{}' parameter", key))
})
}
pub fn get_optional_bool(&self, key: &str) -> bool {
self.params
.get(key)
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
pub fn get_optional_f64(&self, key: &str) -> Option<f64> {
self.params.get(key)?.as_f64()
}
pub fn get_optional_u64(&self, key: &str) -> Option<u64> {
self.params.get(key)?.as_u64()
}
pub fn inner(&self) -> &HashMap<String, serde_json::Value> {
self.params
}
pub fn backend_node_id_from_index(
&self,
index: u32,
selector_map: Option<&HashMap<u32, crate::dom::views::DOMInteractedElement>>,
) -> u32 {
if let Some(map) = selector_map {
if let Some(element) = map.get(&index) {
return element.backend_node_id.unwrap_or(index);
}
}
index
}
}
/// An action known to the registry, with its prompt text and optional handler.
#[derive(Clone)]
pub struct RegisteredAction {
/// Action name, used as the prefix in `prompt_description`.
pub name: String,
/// Human-readable description, shown after the name in `prompt_description`.
pub description: String,
/// Optional URL domain allow-list; `None` means unrestricted (see
/// `ActionRegistry::_match_domains`).
pub domains: Option<Vec<String>>,
/// Shared handler that executes the action, if one is attached.
pub handler: Option<std::sync::Arc<dyn ActionHandler>>,
}
impl std::fmt::Debug for RegisteredAction {
    /// Formats the action for diagnostics.
    ///
    /// The handler is a trait object without a `Debug` bound, so only its presence
    /// is reported (as the string `"Some(handler)"` or `"None"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handler_label: &str = match self.handler {
            Some(_) => "Some(handler)",
            None => "None",
        };
        let mut out = f.debug_struct("RegisteredAction");
        out.field("name", &self.name);
        out.field("description", &self.description);
        out.field("domains", &self.domains);
        out.field("handler", &handler_label);
        out.finish()
    }
}
impl RegisteredAction {
    /// Renders the single `"<name>: <description>"` line used to list this action in a prompt.
    pub fn prompt_description(&self) -> String {
        let mut line = String::with_capacity(self.name.len() + 2 + self.description.len());
        line.push_str(&self.name);
        line.push_str(": ");
        line.push_str(&self.description);
        line
    }
}
/// Collection of registered actions.
#[derive(Debug, Clone, Default)]
pub struct ActionRegistry {
/// Registered actions keyed by a string (presumably the action name — TODO confirm).
pub actions: HashMap<String, RegisteredAction>,
}
impl ActionRegistry {
pub fn new() -> Self {
Self {
actions: HashMap::new(),
}
}
pub fn _match_domains(domains: &Option<Vec<String>>, url: &str) -> bool {
if domains.is_none() || url.is_empty() {
return true;
}
let domains = domains.as_ref().unwrap();
for domain_pattern in domains {
if crate::utils::match_url_with_domain_pattern(url, domain_pattern) {
return true;
}
}
false
}
pub fn get_prompt_description(&self, page_url: Option<&str>) -> String {
if page_url.is_none() {
return self
.actions
.values()
.filter(|action| action.domains.is_none())
.map(|action| action.prompt_description())
.collect::<Vec<_>>()
.join("\n");
}
let page_url = page_url.unwrap();
self.actions
.values()
.filter(|action| {
if action.domains.is_none() {
return false; }
Self::_match_domains(&action.domains, page_url)
})
.map(|action| action.prompt_description())
.collect::<Vec<_>>()
.join("\n")
}
}