use crate::cache_config::CacheConfig;
use crate::client::Client;
use crate::pair_code::PairCodeOptions;
use crate::store::commands::DeviceCommand;
use crate::store::persistence_manager::PersistenceManager;
use crate::store::traits::Backend;
use crate::types::enc_handler::EncHandler;
use crate::types::events::{Event, EventHandler};
use crate::types::message::MessageInfo;
use anyhow::Result;
use log::{info, warn};
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
use wacore::runtime::Runtime;
use wacore::store::DevicePropsOverride;
use waproto::whatsapp as wa;
/// Typestate marker: the matching required `BotBuilder` component has not been supplied yet.
pub struct Missing;
/// Typestate marker: the matching required `BotBuilder` component has been supplied.
pub struct Provided;
/// Error type returned by `BotBuilder::build`.
#[derive(Debug, Error)]
pub enum BotBuilderError {
/// Any underlying failure, wrapped from an `anyhow::Error` and displayed transparently.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// A message bundled with its metadata and the client used to act on it.
///
/// The helper methods on this type (send, edit, revoke, quote) all target
/// the chat recorded in `info.source.chat`.
#[derive(Clone)]
pub struct MessageContext {
/// The message payload, shared behind an `Arc`.
pub message: Arc<wa::Message>,
/// Metadata of the message (its id and source: chat, sender, flags).
pub info: MessageInfo,
/// Client through which replies, edits and revocations are issued.
pub client: Arc<Client>,
}
impl MessageContext {
    /// Builds a context from borrowed parts.
    ///
    /// The message is cloned into a fresh `Arc`; use [`MessageContext::from_arc`]
    /// when an `Arc<wa::Message>` is already available.
    pub fn from_parts(message: &wa::Message, info: &MessageInfo, client: Arc<Client>) -> Self {
        let shared_message = Arc::new(message.clone());
        Self::from_arc(shared_message, info, client)
    }

    /// Builds a context around an already shared message; only `info` is cloned.
    pub fn from_arc(message: Arc<wa::Message>, info: &MessageInfo, client: Arc<Client>) -> Self {
        let info = info.clone();
        Self {
            message,
            info,
            client,
        }
    }

    /// Builds a context from an event.
    ///
    /// Returns `None` when `event.as_message()` yields `None`.
    pub fn from_event(event: &Event, client: Arc<Client>) -> Option<Self> {
        event
            .as_message()
            .map(|(msg, info)| Self::from_arc(Arc::clone(msg), info, client))
    }

    /// Sends `message` to the chat this context's message came from.
    pub async fn send_message(
        &self,
        message: wa::Message,
    ) -> Result<crate::send::SendResult, anyhow::Error> {
        let chat = self.info.source.chat.clone();
        self.client.send_message(chat, message).await
    }

    /// Builds the `ContextInfo` used to quote this context's message,
    /// from its id, sender, chat and payload.
    pub fn build_quote_context(&self) -> wa::ContextInfo {
        let info = &self.info;
        wacore::proto_helpers::build_quote_context_with_info(
            &info.id,
            &info.source.sender,
            &info.source.chat,
            &self.message,
        )
    }

    /// Builds the `MessageKey` identifying this context's message.
    ///
    /// `participant` is only filled in for group chats and status broadcasts.
    pub fn message_key(&self) -> wa::MessageKey {
        use wacore_binary::JidExt;
        let source = &self.info.source;
        let participant = if source.is_group || source.chat.is_status_broadcast() {
            Some(source.sender.to_string())
        } else {
            None
        };
        wa::MessageKey {
            remote_jid: Some(source.chat.to_string()),
            from_me: Some(source.is_from_me),
            id: Some(self.info.id.clone()),
            participant,
        }
    }

    /// Edits the message `original_message_id` in this context's chat,
    /// replacing its content with `new_message`.
    pub async fn edit_message(
        &self,
        original_message_id: impl Into<String>,
        new_message: wa::Message,
    ) -> Result<String, anyhow::Error> {
        let chat = self.info.source.chat.clone();
        self.client
            .edit_message(chat, original_message_id, new_message)
            .await
    }

    /// Revokes the message `message_id` in this context's chat using `revoke_type`.
    pub async fn revoke_message(
        &self,
        message_id: String,
        revoke_type: crate::send::RevokeType,
    ) -> Result<(), anyhow::Error> {
        let chat = self.info.source.chat.clone();
        self.client
            .revoke_message(chat, message_id, revoke_type)
            .await
    }
}
/// Type-erased asynchronous event callback: receives the event and the client
/// and returns a boxed `Send` future producing `()`.
type EventHandlerCallback =
Arc<dyn Fn(Arc<Event>, Arc<Client>) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
/// Adapter that lets a user-supplied async callback be registered as an `EventHandler`.
struct BotEventHandler {
/// Client handed to the callback and whose runtime runs it.
client: Arc<Client>,
/// The user callback, if one was configured.
event_handler: Option<EventHandlerCallback>,
}
impl EventHandler for BotEventHandler {
    /// Runs the user callback for `event` on the client's runtime as a detached task.
    ///
    /// Does nothing when no callback was configured.
    fn handle_event(&self, event: Arc<Event>) {
        let Some(callback) = self.event_handler.as_ref() else {
            return;
        };
        // Owned handles are needed because the spawned future outlives this call.
        let callback = Arc::clone(callback);
        let client = Arc::clone(&self.client);
        self.client
            .runtime
            .spawn(Box::pin(async move {
                callback(event, client).await;
            }))
            .detach();
    }
}
/// Handle to the client task started by `Bot::run`.
///
/// It can be awaited to learn when that task has finished, or used to abort it.
pub struct BotHandle {
/// Receives `()` once the spawned client task has run to completion.
done_rx: futures::channel::oneshot::Receiver<()>,
/// Abort handle of the spawned client task; used by `abort`.
_abort_handle: wacore::runtime::AbortHandle,
}
impl BotHandle {
    /// Calls `abort()` on the handle of the client task spawned by `Bot::run`.
    pub fn abort(&self) {
        let Self {
            _abort_handle: task,
            ..
        } = self;
        task.abort();
    }
}
impl std::future::Future for BotHandle {
    /// `Ok(())` once the client task signalled completion; `Err(Canceled)` if the
    /// sending half was dropped without signalling.
    type Output = Result<(), futures::channel::oneshot::Canceled>;

    /// Delegates to the completion receiver.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.get_mut();
        let receiver = Pin::new(&mut this.done_rx);
        receiver.poll(cx)
    }
}
/// A configured client plus the one-shot resources `run` needs to start it.
///
/// Created through `Bot::builder()`. The optional fields are taken (left as
/// `None`) by `run`.
pub struct Bot {
/// The underlying client.
client: Arc<Client>,
/// Receiver of major sync tasks, drained by the sync worker that `run` spawns.
sync_task_receiver: Option<async_channel::Receiver<crate::sync_task::MajorSyncTask>>,
/// User event callback, registered on the event bus by `run`.
event_handler: Option<EventHandlerCallback>,
/// When set, `run` spawns a task that requests a pair code with these options.
pair_code_options: Option<PairCodeOptions>,
}
impl std::fmt::Debug for Bot {
    /// Prints a placeholder for the client and, for each optional field,
    /// only whether it is currently set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_sync_receiver = self.sync_task_receiver.is_some();
        let has_event_handler = self.event_handler.is_some();
        let has_pair_code = self.pair_code_options.is_some();
        f.debug_struct("Bot")
            .field("client", &"<Client>")
            .field("sync_task_receiver", &has_sync_receiver)
            .field("event_handler", &has_event_handler)
            .field("pair_code_options", &has_pair_code)
            .finish()
    }
}
impl Bot {
pub fn builder() -> BotBuilder<Missing, Missing, Missing, Missing> {
BotBuilder::new()
}
pub fn client(&self) -> Arc<Client> {
self.client.clone()
}
pub async fn run(&mut self) -> Result<BotHandle> {
if let Some(receiver) = self.sync_task_receiver.take() {
let worker_client = Arc::downgrade(&self.client);
self.client
.runtime
.spawn(Box::pin(async move {
while let Ok(task) = receiver.recv().await {
let Some(worker_client) = worker_client.upgrade() else {
break;
};
worker_client.process_sync_task(task).await;
}
info!("Sync worker shutting down.");
}))
.detach();
}
let handler = Arc::new(BotEventHandler {
client: self.client.clone(),
event_handler: self.event_handler.take(),
});
self.client.core.event_bus.add_handler(handler);
if let Some(options) = self.pair_code_options.take() {
let client_for_pair = self.client.clone();
self.client.runtime.spawn(Box::pin(async move {
if let Err(e) = client_for_pair
.wait_for_socket(std::time::Duration::from_secs(30))
.await
{
warn!(target: "Bot/PairCode", "Timeout waiting for socket: {}", e);
return;
}
if client_for_pair.is_logged_in() {
info!(target: "Bot/PairCode", "Already logged in, skipping pair code request");
return;
}
match client_for_pair.pair_with_code(options).await {
Ok(code) => {
info!(target: "Bot/PairCode", "Pair code generated: {}", code);
}
Err(e) => {
warn!(target: "Bot/PairCode", "Failed to request pair code: {}", e);
}
}
})).detach();
}
let client_for_run = self.client.clone();
let (done_tx, done_rx) = futures::channel::oneshot::channel::<()>();
let abort_handle = self.client.runtime.spawn(Box::pin(async move {
client_for_run.run().await;
let _ = done_tx.send(());
}));
Ok(BotHandle {
done_rx,
_abort_handle: abort_handle,
})
}
}
/// Typestate builder for `Bot`.
///
/// The type parameters record whether each required component has been
/// supplied (`Missing` or `Provided`): `B` for the backend, `T` for the
/// transport factory, `H` for the HTTP client and `R` for the runtime.
/// `build` is only available once all four are `Provided`.
pub struct BotBuilder<B = Missing, T = Missing, H = Missing, R = Missing> {
// Required components, each set by its `with_*` method.
backend: Option<Arc<dyn Backend>>,
transport_factory: Option<Arc<dyn crate::transport::TransportFactory>>,
http_client: Option<Arc<dyn crate::http::HttpClient>>,
runtime: Option<Arc<dyn Runtime>>,
// Optional settings.
/// User event callback set by `on_event`.
event_handler: Option<EventHandlerCallback>,
/// Extra enc handlers keyed by enc type, installed on the client by `build`.
custom_enc_handlers: HashMap<String, Arc<dyn EncHandler>>,
/// Version triple passed through to the client constructor.
override_version: Option<(u32, u32, u32)>,
/// Device props override applied to the persistence manager by `build` when non-empty.
device_props_override: Option<DevicePropsOverride>,
/// Pair-code options handed to the built `Bot`.
pair_code_options: Option<PairCodeOptions>,
/// When true, `build` enables skip-history-sync on the client.
skip_history_sync: bool,
/// Push name stored through the persistence manager by `build`.
initial_push_name: Option<String>,
/// Cache configuration passed through to the client constructor.
cache_config: CacheConfig,
/// Carries the typestate parameters; holds no data.
_marker: PhantomData<(B, T, H, R)>,
}
impl BotBuilder<Missing, Missing, Missing, Missing> {
fn new() -> Self {
Self {
backend: None,
transport_factory: None,
http_client: None,
runtime: None,
event_handler: None,
custom_enc_handlers: HashMap::new(),
override_version: None,
device_props_override: None,
pair_code_options: None,
skip_history_sync: false,
initial_push_name: None,
cache_config: CacheConfig::default(),
_marker: PhantomData,
}
}
}
impl<T, H, R> BotBuilder<Missing, T, H, R> {
    /// Supplies the storage backend and marks it as `Provided` in the builder's type.
    ///
    /// All other settings are carried over unchanged.
    pub fn with_backend(self, backend: Arc<dyn Backend>) -> BotBuilder<Provided, T, H, R> {
        let BotBuilder {
            transport_factory,
            http_client,
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            ..
        } = self;
        BotBuilder {
            backend: Some(backend),
            transport_factory,
            http_client,
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            _marker: PhantomData,
        }
    }
}
impl<B, H, R> BotBuilder<B, Missing, H, R> {
    /// Supplies the transport factory and marks it as `Provided` in the builder's type.
    ///
    /// The factory is stored behind an `Arc`; all other settings are carried over unchanged.
    pub fn with_transport_factory<F>(self, factory: F) -> BotBuilder<B, Provided, H, R>
    where
        F: crate::transport::TransportFactory + 'static,
    {
        let shared_factory: Arc<dyn crate::transport::TransportFactory> = Arc::new(factory);
        let BotBuilder {
            backend,
            http_client,
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            ..
        } = self;
        BotBuilder {
            backend,
            transport_factory: Some(shared_factory),
            http_client,
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            _marker: PhantomData,
        }
    }
}
impl<B, T, R> BotBuilder<B, T, Missing, R> {
    /// Supplies the HTTP client and marks it as `Provided` in the builder's type.
    ///
    /// The client is stored behind an `Arc`; all other settings are carried over unchanged.
    pub fn with_http_client<C>(self, client: C) -> BotBuilder<B, T, Provided, R>
    where
        C: crate::http::HttpClient + 'static,
    {
        let shared_client: Arc<dyn crate::http::HttpClient> = Arc::new(client);
        let BotBuilder {
            backend,
            transport_factory,
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            ..
        } = self;
        BotBuilder {
            backend,
            transport_factory,
            http_client: Some(shared_client),
            runtime,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            _marker: PhantomData,
        }
    }
}
impl<B, T, H> BotBuilder<B, T, H, Missing> {
    /// Supplies the async runtime and marks it as `Provided` in the builder's type.
    ///
    /// The runtime is stored behind an `Arc`; all other settings are carried over unchanged.
    pub fn with_runtime<Rt: Runtime>(self, runtime: Rt) -> BotBuilder<B, T, H, Provided> {
        let shared_runtime: Arc<dyn Runtime> = Arc::new(runtime);
        let BotBuilder {
            backend,
            transport_factory,
            http_client,
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            ..
        } = self;
        BotBuilder {
            backend,
            transport_factory,
            http_client,
            runtime: Some(shared_runtime),
            event_handler,
            custom_enc_handlers,
            override_version,
            device_props_override,
            pair_code_options,
            skip_history_sync,
            initial_push_name,
            cache_config,
            _marker: PhantomData,
        }
    }
}
impl<B, T, H, R> BotBuilder<B, T, H, R> {
pub fn on_event<F, Fut>(mut self, handler: F) -> Self
where
F: Fn(Arc<Event>, Arc<Client>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
self.event_handler = Some(Arc::new(move |event, client| {
Box::pin(handler(event, client))
}));
self
}
pub fn with_enc_handler<Eh>(mut self, enc_type: impl Into<String>, handler: Eh) -> Self
where
Eh: EncHandler + 'static,
{
self.custom_enc_handlers
.insert(enc_type.into(), Arc::new(handler));
self
}
pub fn with_version(mut self, version: (u32, u32, u32)) -> Self {
self.override_version = Some(version);
self
}
pub fn with_device_props(mut self, override_: DevicePropsOverride) -> Self {
self.device_props_override = Some(override_);
self
}
pub fn with_pair_code(mut self, options: PairCodeOptions) -> Self {
self.pair_code_options = Some(options);
self
}
pub fn skip_history_sync(mut self) -> Self {
self.skip_history_sync = true;
self
}
pub fn with_push_name(mut self, name: impl Into<String>) -> Self {
self.initial_push_name = Some(name.into());
self
}
pub fn with_cache_config(mut self, config: CacheConfig) -> Self {
self.cache_config = config;
self
}
}
impl BotBuilder<Provided, Provided, Provided, Provided> {
pub async fn build(self) -> std::result::Result<Bot, BotBuilderError> {
let (Some(runtime), Some(backend), Some(transport_factory), Some(http_client)) = (
self.runtime,
self.backend,
self.transport_factory,
self.http_client,
) else {
unreachable!("typestate guarantees all required fields are Provided")
};
let persistence_manager = Arc::new(
PersistenceManager::new(backend)
.await
.map_err(|e| anyhow::anyhow!("Failed to create persistence manager: {}", e))?,
);
if let Some(name) = self.initial_push_name {
persistence_manager
.process_command(DeviceCommand::SetPushName(name))
.await;
}
if let Some(override_) = self.device_props_override
&& !override_.is_empty()
{
info!("Applying device props override: {:?}", override_);
persistence_manager
.process_command(DeviceCommand::SetDeviceProps(override_))
.await;
}
info!("Creating client...");
let (client, sync_task_receiver) = Client::new_with_cache_config(
runtime.clone(),
persistence_manager.clone(),
transport_factory,
http_client,
self.override_version,
self.cache_config,
)
.await;
let saver_handle = persistence_manager.run_background_saver(
runtime,
std::time::Duration::from_secs(30),
client.shutdown_signal(),
);
let _ = client.saver_handle.set(saver_handle);
for (enc_type, handler) in self.custom_enc_handlers {
client
.custom_enc_handlers
.write()
.await
.insert(enc_type, handler);
}
if self.skip_history_sync {
client.set_skip_history_sync(true);
}
Ok(Bot {
client,
sync_task_receiver: Some(sync_task_receiver),
event_handler: self.event_handler,
pair_code_options: self.pair_code_options,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TokioRuntime;
use crate::http::{HttpClient, HttpRequest, HttpResponse};
use crate::store::SqliteStore;
use whatsapp_rust_tokio_transport::TokioWebSocketTransportFactory;
/// Test double for `HttpClient`; its `execute` always returns the same canned response.
#[derive(Debug, Clone)]
struct MockHttpClient;
#[async_trait::async_trait]
impl HttpClient for MockHttpClient {
    /// Ignores the request and returns a fixed 200 response whose body embeds
    /// `server_revision` / `client_revision` values.
    async fn execute(&self, _request: HttpRequest) -> Result<HttpResponse> {
        let body = br#"self.__swData=JSON.parse(/*BTDS*/"{\"dynamic_data\":{\"SiteData\":{\"server_revision\":1026131876,\"client_revision\":1026131876}}}");"#.to_vec();
        let response = HttpResponse {
            status_code: 200,
            body,
        };
        Ok(response)
    }
}
/// Creates a `SqliteStore` backend on a uniquely named shared in-memory database.
async fn create_test_sqlite_backend() -> Arc<dyn Backend> {
    let unique = uuid::Uuid::new_v4();
    let db_uri = format!("file:memdb_bot_{}?mode=memory&cache=shared", unique);
    let store = SqliteStore::new(&db_uri)
        .await
        .expect("Failed to create test SqliteStore");
    let backend: Arc<dyn Backend> = Arc::new(store);
    backend
}
/// Like `create_test_sqlite_backend`, but opens the store for the given `device_id`.
async fn create_test_sqlite_backend_for_device(device_id: i32) -> Arc<dyn Backend> {
    let unique = uuid::Uuid::new_v4();
    let db_uri = format!("file:memdb_bot_{}?mode=memory&cache=shared", unique);
    let store = SqliteStore::new_for_device(&db_uri, device_id)
        .await
        .expect("Failed to create test SqliteStore");
    let backend: Arc<dyn Backend> = Arc::new(store);
    backend
}
/// A bot can be built from the default test backend and yields a client.
#[tokio::test]
async fn test_bot_builder_single_device() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_runtime(TokioRuntime);
    let bot = builder.build().await.expect("Failed to build bot");
    let _client = bot.client();
}
/// A bot can be built on a backend opened for a specific device id (42).
#[tokio::test]
async fn test_bot_builder_multi_device() {
    let store = create_test_sqlite_backend_for_device(42).await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_runtime(TokioRuntime);
    let bot = builder.build().await.expect("Failed to build bot");
    let _client = bot.client();
}
/// Building succeeds when the backend is passed in as an `Arc<dyn Backend>`.
#[tokio::test]
async fn test_bot_builder_with_custom_backend() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with custom backend");
    let _client = bot.client();
}
/// Building succeeds for a device-specific backend, with the HTTP client
/// supplied before the transport factory.
#[tokio::test]
async fn test_bot_builder_with_custom_backend_specific_device() {
    let store = create_test_sqlite_backend_for_device(100).await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_http_client(MockHttpClient)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with custom backend for specific device");
    let _client = bot.client();
}
/// `with_version` is reflected in the built client's `override_version`.
#[tokio::test]
async fn test_bot_builder_with_version_override() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_version((2, 3000, 123456789))
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with version override");
    let built_client = bot.client();
    assert_eq!(built_client.override_version, Some((2, 3000, 123456789)));
}
/// Overriding both OS and version is visible in the device snapshot.
#[tokio::test]
async fn test_bot_builder_with_device_props_override() {
    let custom_os = "CustomOS".to_string();
    let custom_version = wa::device_props::AppVersion {
        primary: Some(99),
        secondary: Some(88),
        tertiary: Some(77),
        ..Default::default()
    };
    let props = DevicePropsOverride::new()
        .with_os(custom_os.clone())
        .with_version(custom_version);

    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_device_props(props)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with device props override");

    let built_client = bot.client();
    let manager = built_client.persistence_manager();
    let snapshot = manager.get_device_snapshot().await;
    assert_eq!(snapshot.device_props.os, Some(custom_os));
    assert_eq!(snapshot.device_props.version, Some(custom_version));
}
/// Overriding only the OS leaves the version at its default.
#[tokio::test]
async fn test_bot_builder_with_os_only_override() {
    let custom_os = "CustomOS".to_string();
    let props = DevicePropsOverride::new().with_os(custom_os.clone());

    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_device_props(props)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with OS only override");

    let built_client = bot.client();
    let manager = built_client.persistence_manager();
    let snapshot = manager.get_device_snapshot().await;
    assert_eq!(snapshot.device_props.os, Some(custom_os));
    assert_eq!(
        snapshot.device_props.version,
        Some(wacore::store::Device::default_device_props_version())
    );
}
/// Overriding only the version leaves the OS at its default.
#[tokio::test]
async fn test_bot_builder_with_version_only_override() {
    let custom_version = wa::device_props::AppVersion {
        primary: Some(99),
        secondary: Some(88),
        tertiary: Some(77),
        ..Default::default()
    };
    let props = DevicePropsOverride::new().with_version(custom_version);

    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_http_client(MockHttpClient)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_device_props(props)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with version only override");

    let built_client = bot.client();
    let manager = built_client.persistence_manager();
    let snapshot = manager.get_device_snapshot().await;
    assert_eq!(snapshot.device_props.version, Some(custom_version));
    assert_eq!(
        snapshot.device_props.os,
        Some(wacore::store::Device::default_os().to_string())
    );
}
/// Overriding only the platform type leaves OS and version at their defaults.
#[tokio::test]
async fn test_bot_builder_with_platform_type_override() {
    let props =
        DevicePropsOverride::new().with_platform_type(wa::device_props::PlatformType::Chrome);

    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_device_props(props)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with platform type override");

    let built_client = bot.client();
    let manager = built_client.persistence_manager();
    let snapshot = manager.get_device_snapshot().await;
    assert_eq!(
        snapshot.device_props.platform_type,
        Some(wa::device_props::PlatformType::Chrome as i32)
    );
    assert_eq!(
        snapshot.device_props.os,
        Some(wacore::store::Device::default_os().to_string())
    );
    assert_eq!(
        snapshot.device_props.version,
        Some(wacore::store::Device::default_device_props_version())
    );
}
/// Overriding OS, version and platform type together is visible in the device snapshot.
#[tokio::test]
async fn test_bot_builder_with_full_device_props_override() {
    let custom_os = "macOS".to_string();
    let custom_version = wa::device_props::AppVersion {
        primary: Some(2),
        secondary: Some(0),
        tertiary: Some(0),
        ..Default::default()
    };
    let custom_platform = wa::device_props::PlatformType::Safari;
    let props = DevicePropsOverride::new()
        .with_os(custom_os.clone())
        .with_version(custom_version)
        .with_platform_type(custom_platform);

    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_device_props(props)
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with full device props override");

    let built_client = bot.client();
    let manager = built_client.persistence_manager();
    let snapshot = manager.get_device_snapshot().await;
    assert_eq!(snapshot.device_props.os, Some(custom_os));
    assert_eq!(snapshot.device_props.version, Some(custom_version));
    assert_eq!(
        snapshot.device_props.platform_type,
        Some(custom_platform as i32)
    );
}
/// `skip_history_sync()` on the builder enables the flag on the built client.
#[tokio::test]
async fn test_bot_builder_skip_history_sync() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .skip_history_sync()
        .with_runtime(TokioRuntime);
    let bot = builder
        .build()
        .await
        .expect("Failed to build bot with skip_history_sync");
    assert!(bot.client().skip_history_sync_enabled());
}
/// Without `skip_history_sync()`, the flag stays disabled on the built client.
#[tokio::test]
async fn test_bot_builder_default_history_sync_enabled() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_runtime(TokioRuntime);
    let bot = builder.build().await.expect("Failed to build bot");
    assert!(!bot.client().skip_history_sync_enabled());
}
/// `MessageContext::from_arc` keeps pointing at the same allocation as the `Arc` it was given.
#[tokio::test]
async fn from_arc_does_not_deep_clone() {
    let store = create_test_sqlite_backend().await;
    let builder = Bot::builder()
        .with_backend(store)
        .with_transport_factory(TokioWebSocketTransportFactory::new())
        .with_http_client(MockHttpClient)
        .with_runtime(TokioRuntime);
    let bot = builder.build().await.expect("Failed to build bot");

    let shared = Arc::new(wa::Message {
        conversation: Some("ping".to_string()),
        ..Default::default()
    });
    let expected_ptr = Arc::as_ptr(&shared);
    let info = MessageInfo::default();
    let ctx = MessageContext::from_arc(Arc::clone(&shared), &info, bot.client());
    assert!(std::ptr::eq(Arc::as_ptr(&ctx.message), expected_ptr));
}
}