Skip to main content

dbus_router/
session.rs

1//! Client session handling with dual upstream connections and routing
2
3use crate::auth;
4use crate::config::Config;
5use crate::dbus_daemon::{
6    build_list_names_response, merge_list_names, needs_request_rewrite, needs_response_rewrite,
7    parse_string_array, rewrite_match_rule_body, rewrite_name_owner_changed,
8    rewrite_single_name_response, rewrite_string_array_response, rewrite_unique_name_request,
9    signal_needs_rewrite,
10};
11use crate::fake_name::get_bus_from_fake_name;
12use crate::message::{self, read_message, Message, MessageType};
13use crate::message_format::format_message;
14use crate::message_rewrite::{parse_match_rule, rewrite_message_header, RewriteDirection};
15use anyhow::{bail, Result};
16use std::collections::{HashMap, HashSet};
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use tokio::io::AsyncWriteExt;
20use tokio::net::UnixStream;
21
/// Identifies one of the two upstream session buses a message can be routed to
/// or can originate from.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Bus {
    /// The host session bus.
    Host,
    /// The sandbox session bus.
    Sandbox,
}
28
29/// Log a message with all relevant fields for debugging.
30fn log_message(msg: &Message, direction: &str, target: Option<Bus>) {
31    let target_str = match target {
32        Some(Bus::Host) => " -> Host",
33        Some(Bus::Sandbox) => " -> Sandbox",
34        None => "",
35    };
36
37    tracing::trace!(
38        direction = direction,
39        target = target_str,
40        msg_type = ?msg.header.msg_type,
41        serial = msg.header.serial,
42        reply_serial = ?msg.header.reply_serial,
43        sender = ?msg.header.sender,
44        destination = ?msg.header.destination,
45        interface = ?msg.header.interface,
46        member = ?msg.header.member,
47        body_len = msg.header.body_len,
48        "{}",
49        format_message_summary(msg)
50    );
51
52    // Also emit dbus-monitor style log with separate target for easy filtering
53    tracing::trace!(
54        target: "dbus_monitor",
55        "{}",
56        format_message(msg, direction, target)
57    );
58}
59
60/// Format a one-line summary of the message for logging.
61fn format_message_summary(msg: &Message) -> String {
62    let type_str = match msg.header.msg_type {
63        MessageType::MethodCall => "CALL",
64        MessageType::MethodReturn => "REPLY",
65        MessageType::Error => "ERROR",
66        MessageType::Signal => "SIGNAL",
67        MessageType::Invalid => "INVALID",
68    };
69
70    let interface = msg.header.interface.as_deref().unwrap_or("-");
71    let member = msg.header.member.as_deref().unwrap_or("-");
72    let dest = msg.header.destination.as_deref().unwrap_or("-");
73    let sender = msg.header.sender.as_deref().unwrap_or("-");
74
75    format!(
76        "{} {}.{} [{}->{}] serial={}",
77        type_str, interface, member, sender, dest, msg.header.serial
78    )
79}
80
81/// Prepare a message from an upstream bus for forwarding to the client.
82/// Rewrites sender header and body as needed.
83fn prepare_message_for_client(
84    msg: &Message,
85    source_bus: Bus,
86    pending_calls: &HashMap<u32, PendingCallInfo>,
87) -> Vec<u8> {
88    let mut msg_for_client = msg.clone();
89
90    // Log message context for debugging header rewriting issues
91    tracing::trace!(
92        msg_type = ?msg.header.msg_type,
93        sender = ?msg.header.sender,
94        destination = ?msg.header.destination,
95        interface = ?msg.header.interface,
96        member = ?msg.header.member,
97        source_bus = ?source_bus,
98        "Preparing message for client"
99    );
100
101    // Rewrite sender header to add bus prefix
102    if let Err(e) =
103        rewrite_message_header(&mut msg_for_client, RewriteDirection::ToClient, source_bus)
104    {
105        tracing::warn!(
106            error = %e,
107            msg_type = ?msg.header.msg_type,
108            sender = ?msg.header.sender,
109            destination = ?msg.header.destination,
110            interface = ?msg.header.interface,
111            member = ?msg.header.member,
112            "Failed to rewrite sender header"
113        );
114    }
115
116    // Handle body rewriting based on message type
117    match msg_for_client.header.msg_type {
118        MessageType::MethodReturn => {
119            rewrite_method_return_body(&msg_for_client, source_bus, pending_calls)
120        }
121        MessageType::Signal => rewrite_signal_body(&msg_for_client, source_bus),
122        _ => msg_for_client.raw,
123    }
124}
125
126/// Rewrite method return body if it contains unique names that need prefixing.
127fn rewrite_method_return_body(
128    msg: &Message,
129    source_bus: Bus,
130    pending_calls: &HashMap<u32, PendingCallInfo>,
131) -> Vec<u8> {
132    let Some(reply_serial) = msg.header.reply_serial else {
133        return msg.raw.clone();
134    };
135
136    let Some(call_info) = pending_calls.get(&reply_serial) else {
137        return msg.raw.clone();
138    };
139
140    if call_info.bus != source_bus {
141        return msg.raw.clone();
142    }
143
144    let Some(ref member) = call_info.member else {
145        return msg.raw.clone();
146    };
147
148    if !needs_response_rewrite(member) {
149        return msg.raw.clone();
150    }
151
152    // ListQueuedOwners returns an array of unique names
153    let result = if member == "ListQueuedOwners" {
154        rewrite_string_array_response(msg, source_bus)
155    } else {
156        rewrite_single_name_response(msg, source_bus)
157    };
158
159    match result {
160        Ok(rewritten) => {
161            tracing::trace!(member = member, bus = ?source_bus, "Rewrote response body");
162            rewritten
163        }
164        Err(e) => {
165            tracing::warn!(member = member, error = %e, "Failed to rewrite response body");
166            msg.raw.clone()
167        }
168    }
169}
170
171/// Rewrite signal body if it contains unique names that need prefixing.
172fn rewrite_signal_body(msg: &Message, source_bus: Bus) -> Vec<u8> {
173    let Some(ref member) = msg.header.member else {
174        return msg.raw.clone();
175    };
176
177    if !signal_needs_rewrite(member) {
178        return msg.raw.clone();
179    }
180
181    match rewrite_name_owner_changed(msg, source_bus) {
182        Ok(rewritten) => {
183            tracing::trace!(member = member, bus = ?source_bus, "Rewrote signal body");
184            rewritten
185        }
186        Err(e) => {
187            tracing::warn!(member = member, error = %e, "Failed to rewrite signal body");
188            msg.raw.clone()
189        }
190    }
191}
192
/// Outcome of feeding one upstream reply into a pending merge.
enum MergeResult {
    /// This was the first of the two replies; it has been kept and nothing is
    /// ready to send yet.
    Stored,
    /// This was the second reply; carries the bytes of the merged message.
    Complete(Vec<u8>),
}
200
201/// Process a ListNames/ListActivatableNames merge response.
202/// Returns `Some(MergeResult)` if the message was a merge response, `None` otherwise.
203fn process_merge_response(
204    msg: &Message,
205    source_bus: Bus,
206    pending_merges: &mut HashMap<u32, PendingMerge>,
207) -> Option<MergeResult> {
208    let reply_serial = msg.header.reply_serial?;
209    let pending = pending_merges.get_mut(&reply_serial)?;
210
211    let names = parse_string_array(msg).unwrap_or_default();
212    tracing::trace!(
213        serial = reply_serial,
214        count = names.len(),
215        bus = ?source_bus,
216        "Received ListNames response"
217    );
218
219    let Some((first_bus, first_names)) = pending.first_response.take() else {
220        // First response - store it
221        pending.first_response = Some((source_bus, names));
222        return Some(MergeResult::Stored);
223    };
224
225    // Second response - merge and build final message
226    let (host_names, sandbox_names) = if first_bus == Bus::Host {
227        (first_names, names)
228    } else {
229        (names, first_names)
230    };
231
232    let merged = merge_list_names(host_names, sandbox_names);
233    tracing::trace!(
234        serial = reply_serial,
235        count = merged.len(),
236        "Merged ListNames response"
237    );
238
239    let pending = pending_merges.remove(&reply_serial).unwrap();
240    let response = match build_list_names_response(&pending.original_request, merged) {
241        Ok(response) => response,
242        Err(e) => {
243            tracing::warn!(error = %e, "Failed to build merged response");
244            msg.raw.clone()
245        }
246    };
247
248    Some(MergeResult::Complete(response))
249}
250
251/// Pending merge state for ListNames/ListActivatableNames
252#[derive(Debug)]
253struct PendingMerge {
254    /// The original request message (for building response)
255    original_request: Message,
256    /// First response received (None if waiting for first)
257    first_response: Option<(Bus, Vec<String>)>,
258}
259
260/// Info about a pending call (for response rewriting)
261#[derive(Debug, Clone)]
262struct PendingCallInfo {
263    /// Which bus the call was sent to
264    bus: Bus,
265    /// The member (method name) of the call (for response rewriting)
266    member: Option<String>,
267}
268
269/// A client session with connections to both upstream buses.
270pub struct Session {
271    /// Connection from sandbox app
272    client: UnixStream,
273    /// Connection to host session bus
274    host_bus: UnixStream,
275    /// Connection to sandbox session bus (default target)
276    sandbox_bus: UnixStream,
277    /// Routing configuration
278    config: Arc<Config>,
279    /// Track which bus each outgoing serial was sent to (for routing replies)
280    pending_calls: HashMap<u32, PendingCallInfo>,
281    /// Client process executable path (for sandbox export permission check)
282    client_exe_path: Option<PathBuf>,
283    /// Services exported by this client to the host bus
284    exported_services: HashSet<String>,
285    /// Services registered by this client on the sandbox bus
286    sandbox_services: HashSet<String>,
287    /// Track incoming calls from upstream buses (serial -> source bus)
288    incoming_calls: HashMap<u32, Bus>,
289    /// Pending ListNames/ListActivatableNames merges (serial -> state)
290    pending_merges: HashMap<u32, PendingMerge>,
291}
292
293/// Get the executable path of a peer process from a Unix socket.
294#[cfg(target_os = "linux")]
295fn get_peer_exe_path(stream: &UnixStream) -> Option<PathBuf> {
296    use std::os::unix::io::AsRawFd;
297
298    let fd = stream.as_raw_fd();
299
300    // Get peer credentials using SO_PEERCRED
301    let mut ucred: libc::ucred = unsafe { std::mem::zeroed() };
302    let mut len = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
303
304    let ret = unsafe {
305        libc::getsockopt(
306            fd,
307            libc::SOL_SOCKET,
308            libc::SO_PEERCRED,
309            &mut ucred as *mut _ as *mut libc::c_void,
310            &mut len,
311        )
312    };
313
314    if ret != 0 {
315        tracing::debug!(
316            "Failed to get peer credentials: {}",
317            std::io::Error::last_os_error()
318        );
319        return None;
320    }
321
322    let pid = ucred.pid;
323    if pid <= 0 {
324        return None;
325    }
326
327    // Read /proc/{pid}/exe symlink
328    let exe_path = format!("/proc/{}/exe", pid);
329    match std::fs::read_link(&exe_path) {
330        Ok(path) => {
331            tracing::debug!(pid = pid, exe = %path.display(), "Got peer exe path");
332            Some(path)
333        }
334        Err(e) => {
335            tracing::debug!(pid = pid, error = %e, "Failed to read exe path");
336            None
337        }
338    }
339}
340
/// Non-Linux stand-in for the peer executable lookup: always returns `None`.
#[cfg(not(target_os = "linux"))]
fn get_peer_exe_path(_stream: &UnixStream) -> Option<PathBuf> {
    Option::None
}
345
346impl Session {
347    /// Create a new session with connections to both upstream buses.
348    pub async fn new(
349        client: UnixStream,
350        host_addr: &str,
351        sandbox_addr: &str,
352        config: Arc<Config>,
353    ) -> Result<Self> {
354        let client_exe_path = get_peer_exe_path(&client);
355        let host_bus = connect_dbus(host_addr).await?;
356        let sandbox_bus = connect_dbus(sandbox_addr).await?;
357
358        Ok(Self {
359            client,
360            host_bus,
361            sandbox_bus,
362            config,
363            pending_calls: HashMap::new(),
364            client_exe_path,
365            exported_services: HashSet::new(),
366            sandbox_services: HashSet::new(),
367            incoming_calls: HashMap::new(),
368            pending_merges: HashMap::new(),
369        })
370    }
371
372    /// Run the session: authenticate with both buses, then forward messages.
373    pub async fn run(mut self) -> Result<()> {
374        // Check if this is a hostpass client
375        let is_hostpass = self
376            .client_exe_path
377            .as_ref()
378            .map(|p| self.config.has_hostpass(p))
379            .unwrap_or(false);
380
381        // Phase 1: Auth passthrough with sandbox bus
382        tracing::debug!("Starting auth phase with sandbox bus");
383        auth::auth_passthrough(&mut self.client, &mut self.sandbox_bus).await?;
384        tracing::info!("Auth with sandbox bus completed");
385
386        // Phase 1b: Also authenticate with host bus (using same credentials)
387        // For hostpass clients, skip Hello() - their Hello() will be forwarded to host bus
388        // For non-hostpass clients, send Hello() so host_routes messages can be routed
389        tracing::debug!("Starting auth phase with host bus");
390        self.auth_host_bus(is_hostpass).await?;
391        tracing::info!("Auth with host bus completed, starting message forwarding");
392
393        // Phase 2: Message forwarding with routing
394        self.forward_loop().await
395    }
396
397    /// Authenticate with the host bus.
398    /// The host bus needs its own auth handshake.
399    /// If `skip_hello` is true, skip the Hello() call (for hostpass clients whose
400    /// Hello() will be forwarded to the host bus).
401    async fn auth_host_bus(&mut self, skip_hello: bool) -> Result<()> {
402        // Send null byte and EXTERNAL auth with hex-encoded UID
403        self.host_bus.write_all(&[0]).await?;
404        let uid = unsafe { libc::getuid() };
405        let uid_hex: String = uid
406            .to_string()
407            .bytes()
408            .map(|b| format!("{:02x}", b))
409            .collect();
410        self.host_bus
411            .write_all(format!("AUTH EXTERNAL {}\r\n", uid_hex).as_bytes())
412            .await?;
413
414        // Read auth response
415        let response = read_auth_line(&mut self.host_bus).await?;
416        if !response.starts_with("OK") {
417            bail!("Host bus auth failed: {}", response.trim());
418        }
419
420        // Negotiate UNIX FD passing
421        self.host_bus.write_all(b"NEGOTIATE_UNIX_FD\r\n").await?;
422        let response = read_auth_line(&mut self.host_bus).await?;
423        tracing::debug!(response = %response.trim(), "Host bus NEGOTIATE_UNIX_FD response");
424
425        // Send BEGIN to complete SASL auth
426        self.host_bus.write_all(b"BEGIN\r\n").await?;
427
428        if skip_hello {
429            tracing::debug!("Skipping Hello() for hostpass client");
430            Ok(())
431        } else {
432            self.send_host_hello().await
433        }
434    }
435
436    /// Send Hello() method call to host bus and read the response.
437    /// This registers the router's connection with the host bus daemon.
438    async fn send_host_hello(&mut self) -> Result<()> {
439        use zvariant::{serialized::Context, to_bytes, ObjectPath, Value, LE};
440
441        // Build header fields array for Hello() call
442        let path = ObjectPath::try_from("/org/freedesktop/DBus").unwrap();
443        let fields: Vec<(u8, Value)> = vec![
444            (1, Value::ObjectPath(path)),                   // PATH
445            (2, Value::Str("org.freedesktop.DBus".into())), // INTERFACE
446            (3, Value::Str("Hello".into())),                // MEMBER
447            (6, Value::Str("org.freedesktop.DBus".into())), // DESTINATION
448        ];
449
450        let ctxt = Context::new_dbus(LE, 12);
451        let fields_encoded = to_bytes(ctxt, &fields)?;
452        let array_len = fields_encoded.len() - 4; // Exclude 4-byte length prefix
453
454        // Calculate padding to 8-byte boundary
455        let header_end = 16 + array_len;
456        let padding = (8 - (header_end % 8)) % 8;
457
458        // Build D-Bus message: fixed header + fields + padding
459        let mut msg = Vec::with_capacity(16 + array_len + padding);
460        msg.extend_from_slice(&[b'l', 1, 0, 1]); // endian, method_call, flags, version
461        msg.extend_from_slice(&0u32.to_le_bytes()); // body length
462        msg.extend_from_slice(&1u32.to_le_bytes()); // serial
463        msg.extend_from_slice(&fields_encoded);
464        msg.resize(msg.len() + padding, 0);
465
466        self.host_bus.write_all(&msg).await?;
467
468        // Read and validate response.
469        // After Hello(), the bus daemon sends:
470        // 1. MethodReturn with our unique name
471        // 2. NameAcquired signal for that name
472        // We must consume both to prevent the signal from being forwarded to client,
473        // which would cause "Unexpected message Signal" errors in strict clients like zbus.
474        //
475        // First, read the MethodReturn
476        match message::read_message(&mut self.host_bus).await? {
477            Some(resp) if resp.header.msg_type == MessageType::MethodReturn => {
478                tracing::debug!("Host bus Hello() MethodReturn received");
479            }
480            Some(resp) if resp.header.msg_type == MessageType::Error => {
481                bail!("Host bus Hello() failed with error")
482            }
483            Some(resp) => bail!("Unexpected response to Hello(): {:?}", resp.header.msg_type),
484            None => bail!("Host bus disconnected after Hello()"),
485        }
486
487        // Then, read the NameAcquired signal
488        match message::read_message(&mut self.host_bus).await? {
489            Some(resp) if resp.header.msg_type == MessageType::Signal => {
490                tracing::debug!(
491                    interface = ?resp.header.interface,
492                    member = ?resp.header.member,
493                    "Consumed NameAcquired signal after Hello()"
494                );
495            }
496            Some(resp) => {
497                tracing::warn!(
498                    msg_type = ?resp.header.msg_type,
499                    "Unexpected second message after Hello(), expected Signal"
500                );
501            }
502            None => bail!("Host bus disconnected after Hello()"),
503        }
504
505        tracing::debug!("Host bus Hello() completed");
506        Ok(())
507    }
508
509    /// Forward messages between client and upstream buses with routing.
510    async fn forward_loop(mut self) -> Result<()> {
511        // Check if this is a hostpass client (they only use host bus)
512        let is_hostpass = self
513            .client_exe_path
514            .as_ref()
515            .map(|p| self.config.has_hostpass(p))
516            .unwrap_or(false);
517
518        let (client_read, mut client_write) = self.client.split();
519        let (host_read, mut host_write) = self.host_bus.split();
520        let (sandbox_read, mut sandbox_write) = self.sandbox_bus.split();
521
522        let mut client_read = tokio::io::BufReader::new(client_read);
523        let mut host_read = tokio::io::BufReader::new(host_read);
524        let mut sandbox_read = tokio::io::BufReader::new(sandbox_read);
525
526        // Track if sandbox bus is still active (for hostpass clients that survive sandbox disconnect)
527        let mut sandbox_active = true;
528
529        loop {
530            tokio::select! {
531                biased;
532                // Read from client and route to appropriate bus
533                result = read_message(&mut client_read) => {
534                    match result {
535                        Ok(Some(msg)) => {
536                            // Check if this is a reply to an incoming call from upstream bus
537                            if matches!(msg.header.msg_type, MessageType::MethodReturn | MessageType::Error) {
538                                if let Some(reply_serial) = msg.header.reply_serial {
539                                    if let Some(source_bus) = self.incoming_calls.remove(&reply_serial) {
540                                        // Rewrite destination from fake name to real name
541                                        let mut msg_for_bus = msg.clone();
542                                        if let Err(e) = rewrite_message_header(
543                                            &mut msg_for_bus,
544                                            RewriteDirection::ToUpstream,
545                                            source_bus,
546                                        ) {
547                                            tracing::warn!(error = %e, "Failed to rewrite destination for reply");
548                                        }
549
550                                        match source_bus {
551                                            Bus::Host => {
552                                                log_message(&msg, "Client->Host(reply)", Some(Bus::Host));
553                                                host_write.write_all(&msg_for_bus.raw).await?;
554                                            }
555                                            Bus::Sandbox => {
556                                                log_message(&msg, "Client->Sandbox(reply)", Some(Bus::Sandbox));
557                                                sandbox_write.write_all(&msg_for_bus.raw).await?;
558                                            }
559                                        }
560                                        continue;
561                                    }
562                                }
563                            }
564
565                            let decision = route_request(&self.config, &msg, self.client_exe_path.as_deref());
566
567                            // Track RequestName calls to know which services this client exports
568                            if msg.is_request_name() {
569                                if let Some(name) = msg.extract_name_from_body() {
570                                    match decision {
571                                        RouteDecision::Single(Bus::Host) => {
572                                            tracing::info!(service = %name, "Client exporting service to host bus");
573                                            self.exported_services.insert(name);
574                                        }
575                                        RouteDecision::Single(Bus::Sandbox) => {
576                                            tracing::info!(service = %name, "Client registering service on sandbox bus");
577                                            self.sandbox_services.insert(name);
578                                        }
579                                        _ => {}
580                                    }
581                                }
582                            }
583
584                            match decision {
585                                RouteDecision::Single(target) => {
586                                    // Prepare message for sending to upstream
587                                    let mut msg_for_upstream = msg.clone();
588
589                                    // 1. Rewrite destination header if it's a fake unique name
590                                    if let Err(e) = rewrite_message_header(
591                                        &mut msg_for_upstream,
592                                        RewriteDirection::ToUpstream,
593                                        target, // source_bus not used for ToUpstream
594                                    ) {
595                                        tracing::warn!(error = %e, "Failed to rewrite destination header");
596                                    }
597
598                                    // 2. Rewrite body for org.freedesktop.DBus methods
599                                    let msg_to_send = if msg_for_upstream.header.destination.as_deref() == Some("org.freedesktop.DBus") {
600                                        if let Some(member) = msg_for_upstream.header.member.as_deref() {
601                                            if needs_request_rewrite(member) {
602                                                // Rewrite unique name in body (GetConnectionCredentials etc.)
603                                                match rewrite_unique_name_request(&msg_for_upstream) {
604                                                    Ok((rewritten, _bus)) => {
605                                                        tracing::trace!(
606                                                            member = member,
607                                                            "Rewrote request body to remove fake prefix"
608                                                        );
609                                                        rewritten
610                                                    }
611                                                    Err(e) => {
612                                                        tracing::warn!(
613                                                            member = member,
614                                                            error = %e,
615                                                            "Failed to rewrite request body"
616                                                        );
617                                                        msg_for_upstream.raw.clone()
618                                                    }
619                                                }
620                                            } else if member == "AddMatch" || member == "RemoveMatch" {
621                                                // Rewrite sender in match rule
622                                                match rewrite_match_rule_body(&msg_for_upstream) {
623                                                    Ok(Some(rewritten)) => {
624                                                        tracing::trace!(
625                                                            member = member,
626                                                            "Rewrote match rule sender"
627                                                        );
628                                                        rewritten
629                                                    }
630                                                    Ok(None) => msg_for_upstream.raw.clone(),
631                                                    Err(e) => {
632                                                        tracing::warn!(
633                                                            member = member,
634                                                            error = %e,
635                                                            "Failed to rewrite match rule"
636                                                        );
637                                                        msg_for_upstream.raw.clone()
638                                                    }
639                                                }
640                                            } else {
641                                                msg_for_upstream.raw.clone()
642                                            }
643                                        } else {
644                                            msg_for_upstream.raw.clone()
645                                        }
646                                    } else {
647                                        msg_for_upstream.raw.clone()
648                                    };
649
650                                    self.pending_calls.insert(msg.header.serial, PendingCallInfo {
651                                        bus: target,
652                                        member: msg.header.member.clone(),
653                                    });
654                                    log_message(&msg, "Client", Some(target));
655
656                                    match target {
657                                        Bus::Host => host_write.write_all(&msg_to_send).await?,
658                                        Bus::Sandbox => sandbox_write.write_all(&msg_to_send).await?,
659                                    }
660                                }
661                                RouteDecision::Both => {
662                                    // Send to both buses (e.g., AddMatch without sender)
663                                    // Rewrite match rule body if needed
664                                    let msg_to_send = if msg.header.member.as_deref() == Some("AddMatch")
665                                        || msg.header.member.as_deref() == Some("RemoveMatch")
666                                    {
667                                        match rewrite_match_rule_body(&msg) {
668                                            Ok(Some(rewritten)) => rewritten,
669                                            Ok(None) => msg.raw.clone(),
670                                            Err(_) => msg.raw.clone(),
671                                        }
672                                    } else {
673                                        msg.raw.clone()
674                                    };
675
676                                    self.pending_calls.insert(msg.header.serial, PendingCallInfo {
677                                        bus: Bus::Sandbox,
678                                        member: msg.header.member.clone(),
679                                    });
680                                    log_message(&msg, "Client->Both", None);
681
682                                    host_write.write_all(&msg_to_send).await?;
683                                    sandbox_write.write_all(&msg_to_send).await?;
684                                }
685                                RouteDecision::Merge => {
686                                    // ListNames/ListActivatableNames: send to both and merge results
687                                    log_message(&msg, "Client->Merge", None);
688
689                                    self.pending_merges.insert(msg.header.serial, PendingMerge {
690                                        original_request: msg.clone(),
691                                        first_response: None,
692                                    });
693
694                                    host_write.write_all(&msg.raw).await?;
695                                    sandbox_write.write_all(&msg.raw).await?;
696                                }
697                            }
698                        }
699                        Ok(None) => {
700                            tracing::debug!("Client disconnected");
701                            return Ok(());
702                        }
703                        Err(e) => {
704                            tracing::debug!(error = %e, "Error reading from client");
705                            return Ok(());
706                        }
707                    }
708                }
709
710                // Read from host bus and forward to client
711                result = read_message(&mut host_read) => {
712                    match result {
713                        Ok(Some(msg)) => {
714                            // Check if this is a call to an exported service
715                            if msg.header.msg_type == MessageType::MethodCall {
716                                if let Some(ref dest) = msg.header.destination {
717                                    if self.exported_services.contains(dest) {
718                                        log_message(&msg, "Host->Client(exported)", None);
719                                        self.incoming_calls.insert(msg.header.serial, Bus::Host);
720                                        client_write.write_all(&msg.raw).await?;
721                                        continue;
722                                    }
723                                }
724                            }
725
726                            // Check if this is a response to a pending merge request
727                            if let Some(result) = process_merge_response(&msg, Bus::Host, &mut self.pending_merges) {
728                                if let MergeResult::Complete(response) = result {
729                                    client_write.write_all(&response).await?;
730                                }
731                                continue;
732                            }
733
734                            let msg_to_send = prepare_message_for_client(
735                                &msg,
736                                Bus::Host,
737                                &self.pending_calls,
738                            );
739                            log_message(&msg, "Host->Client", None);
740                            client_write.write_all(&msg_to_send).await?;
741                        }
742                        Ok(None) => {
743                            tracing::debug!("Host bus disconnected");
744                            return Ok(());
745                        }
746                        Err(e) => {
747                            tracing::debug!(error = %e, "Error reading from host bus");
748                            return Ok(());
749                        }
750                    }
751                }
752
753                // Read from sandbox bus and forward to client
754                result = read_message(&mut sandbox_read), if sandbox_active => {
755                    match result {
756                        Ok(Some(msg)) => {
757                            // Check if this is a response to a pending merge request
758                            if let Some(result) = process_merge_response(&msg, Bus::Sandbox, &mut self.pending_merges) {
759                                if let MergeResult::Complete(response) = result {
760                                    client_write.write_all(&response).await?;
761                                }
762                                continue;
763                            }
764
765                            // Check if this is a call to a sandbox service registered by this client
766                            if msg.header.msg_type == MessageType::MethodCall {
767                                if let Some(ref dest) = msg.header.destination {
768                                    if self.sandbox_services.contains(dest) {
769                                        // Rewrite sender to add :s. prefix before forwarding to client
770                                        let mut msg_for_client = msg.clone();
771                                        if let Err(e) = rewrite_message_header(
772                                            &mut msg_for_client,
773                                            RewriteDirection::ToClient,
774                                            Bus::Sandbox,
775                                        ) {
776                                            tracing::warn!(error = %e, "Failed to rewrite sender header for sandbox service call");
777                                        }
778
779                                        log_message(&msg, "Sandbox->Client(service)", None);
780                                        self.incoming_calls.insert(msg.header.serial, Bus::Sandbox);
781                                        client_write.write_all(&msg_for_client.raw).await?;
782                                        continue;
783                                    }
784                                }
785                            }
786
787                            let msg_to_send = prepare_message_for_client(
788                                &msg,
789                                Bus::Sandbox,
790                                &self.pending_calls,
791                            );
792                            log_message(&msg, "Sandbox->Client", None);
793                            client_write.write_all(&msg_to_send).await?;
794                        }
795                        Ok(None) => {
796                            tracing::debug!("Sandbox bus disconnected");
797                            if is_hostpass {
798                                // Hostpass clients only use host bus, so they can continue
799                                tracing::info!("Hostpass client continues after sandbox disconnect");
800                                sandbox_active = false;
801                                continue;
802                            }
803                            return Ok(());
804                        }
805                        Err(e) => {
806                            tracing::debug!(error = %e, "Error reading from sandbox bus");
807                            if is_hostpass {
808                                tracing::info!("Hostpass client continues after sandbox read error");
809                                sandbox_active = false;
810                                continue;
811                            }
812                            return Ok(());
813                        }
814                    }
815                }
816            }
817        }
818    }
819}
820
821/// Routing decision result
822#[derive(Debug, Clone)]
823pub enum RouteDecision {
824    /// Route to a single bus
825    Single(Bus),
826    /// Route to both buses (for AddMatch without sender)
827    Both,
828    /// Route to both buses and merge results (for ListNames)
829    Merge,
830}
831
832/// Determine which bus to route a request to based on destination.
833fn route_request(config: &Config, msg: &Message, client_exe: Option<&Path>) -> RouteDecision {
834    // Hostpass: route ALL messages from hostpass processes to host bus
835    // This is required because D-Bus requires Hello() before any other operations,
836    // and the host bus won't accept messages from connections that haven't called Hello()
837    if let Some(exe) = client_exe {
838        if config.has_hostpass(exe) {
839            tracing::trace!(exe = %exe.display(), "Routing to host (hostpass)");
840            return RouteDecision::Single(Bus::Host);
841        }
842    }
843
844    // Check if destination is a fake unique name (e.g., :h.1.45)
845    if let Some(ref dest) = msg.header.destination {
846        if let Some(bus) = get_bus_from_fake_name(dest) {
847            tracing::trace!(dest = %dest, bus = ?bus, "Routing by fake unique name");
848            return RouteDecision::Single(bus);
849        }
850    }
851
852    // Special handling for org.freedesktop.DBus methods
853    if msg.header.destination.as_deref() == Some("org.freedesktop.DBus")
854        && msg.header.interface.as_deref() == Some("org.freedesktop.DBus")
855    {
856        return route_dbus_daemon_call(config, msg);
857    }
858
859    // Method calls and signals with a destination are routed based on config
860    if let Some(ref dest) = msg.header.destination {
861        if config.should_route_to_host(dest) {
862            return RouteDecision::Single(Bus::Host);
863        }
864    }
865
866    // Default: route to sandbox bus
867    RouteDecision::Single(Bus::Sandbox)
868}
869
870/// Route calls to org.freedesktop.DBus based on method and arguments
871fn route_dbus_daemon_call(config: &Config, msg: &Message) -> RouteDecision {
872    let member = msg.header.member.as_deref().unwrap_or("");
873
874    match member {
875        // AddMatch/RemoveMatch: route based on sender in the match rule
876        "AddMatch" | "RemoveMatch" => {
877            if let Some(rule) = msg.extract_string_from_body() {
878                let info = parse_match_rule(&rule);
879
880                // If sender is specified, route based on sender
881                if let Some(ref sender) = info.sender {
882                    // Check if sender is a fake unique name
883                    if let Some(bus) = get_bus_from_fake_name(sender) {
884                        tracing::trace!(
885                            member = member,
886                            sender = %sender,
887                            bus = ?bus,
888                            "Routing {} by fake sender name",
889                            member
890                        );
891                        return RouteDecision::Single(bus);
892                    }
893
894                    // Check if sender matches host_routes
895                    if config.should_route_to_host(sender) {
896                        tracing::trace!(
897                            member = member,
898                            sender = %sender,
899                            "Routing {} to host (sender in host_routes)",
900                            member
901                        );
902                        return RouteDecision::Single(Bus::Host);
903                    }
904                }
905
906                // If interface is specified but no sender, try to route by interface
907                if let Some(ref interface) = info.interface {
908                    // Extract service name from interface (e.g., org.fcitx.Fcitx5.Controller1 -> org.fcitx.Fcitx5)
909                    if let Some(service) = extract_service_from_interface(interface) {
910                        if config.should_route_to_host(&service) {
911                            tracing::trace!(
912                                member = member,
913                                interface = %interface,
914                                service = %service,
915                                "Routing {} to host (interface matches host_routes)",
916                                member
917                            );
918                            return RouteDecision::Single(Bus::Host);
919                        }
920                    }
921                }
922
923                // No sender or interface to determine routing - send to both buses
924                tracing::trace!(
925                    member = member,
926                    rule = %rule,
927                    "Routing {} to both buses (no sender specified)",
928                    member
929                );
930                return RouteDecision::Both;
931            }
932
933            // Can't parse rule, default to sandbox
934            RouteDecision::Single(Bus::Sandbox)
935        }
936
937        // RequestName/ReleaseName: route based on the service name being registered
938        "RequestName" | "ReleaseName" => {
939            if let Some(name) = msg.extract_name_from_body() {
940                if config.should_route_to_host(&name) {
941                    tracing::trace!(
942                        member = member,
943                        name = %name,
944                        "Routing {} to host (name in host_routes)",
945                        member
946                    );
947                    return RouteDecision::Single(Bus::Host);
948                }
949            }
950            RouteDecision::Single(Bus::Sandbox)
951        }
952
953        // GetNameOwner/NameHasOwner/StartServiceByName: route based on queried name
954        "GetNameOwner" | "NameHasOwner" | "StartServiceByName" => {
955            if let Some(name) = msg.extract_name_from_body() {
956                // Check if name is a fake unique name
957                if let Some(bus) = get_bus_from_fake_name(&name) {
958                    return RouteDecision::Single(bus);
959                }
960
961                if config.should_route_to_host(&name) {
962                    tracing::trace!(
963                        member = member,
964                        name = %name,
965                        "Routing {} to host (name in host_routes)",
966                        member
967                    );
968                    return RouteDecision::Single(Bus::Host);
969                }
970            }
971            RouteDecision::Single(Bus::Sandbox)
972        }
973
974        // GetConnectionCredentials etc: route based on unique name argument
975        "GetConnectionCredentials"
976        | "GetConnectionUnixUser"
977        | "GetConnectionUnixProcessID"
978        | "GetConnectionSELinuxSecurityContext"
979        | "GetAdtAuditSessionData" => {
980            if let Some(name) = msg.extract_name_from_body() {
981                if let Some(bus) = get_bus_from_fake_name(&name) {
982                    tracing::trace!(
983                        member = member,
984                        name = %name,
985                        bus = ?bus,
986                        "Routing {} by fake unique name",
987                        member
988                    );
989                    return RouteDecision::Single(bus);
990                }
991            }
992            // Unknown unique name, default to sandbox
993            RouteDecision::Single(Bus::Sandbox)
994        }
995
996        // ListNames/ListActivatableNames: merge results from both buses
997        "ListNames" | "ListActivatableNames" => {
998            tracing::trace!(
999                member = member,
1000                "Routing {} to both buses for merge",
1001                member
1002            );
1003            RouteDecision::Merge
1004        }
1005
1006        // Other methods: default to sandbox
1007        _ => RouteDecision::Single(Bus::Sandbox),
1008    }
1009}
1010
/// Guess the owning service name for an interface name.
///
/// Heuristic: the service is taken to be the first three dot-separated
/// segments of the interface, e.g. `"org.fcitx.Fcitx5.Controller1"` becomes
/// `"org.fcitx.Fcitx5"`. No other prefix length is tried.
///
/// Returns `None` when `interface` has fewer than three segments.
fn extract_service_from_interface(interface: &str) -> Option<String> {
    let mut segments = interface.split('.');
    // Each `?` bails out as soon as a required segment is missing.
    let first = segments.next()?;
    let second = segments.next()?;
    let third = segments.next()?;
    Some([first, second, third].join("."))
}
1024
/// Unix-domain socket endpoint extracted from a D-Bus address string.
#[derive(Debug, Clone)]
enum UnixAddress {
    /// Socket bound to a filesystem path (`unix:path=...`).
    Path(PathBuf),
    /// Socket in the abstract namespace (`unix:abstract=...`, Linux only).
    Abstract(String),
}
1033
1034/// Connect to a D-Bus address.
1035/// Supports "unix:path=/path/to/socket" and "unix:abstract=name" formats.
1036async fn connect_dbus(addr: &str) -> Result<UnixStream> {
1037    let unix_addr = parse_unix_address(addr)?;
1038    match unix_addr {
1039        UnixAddress::Path(path) => {
1040            tracing::debug!(path = %path.display(), "Connecting to D-Bus (path)");
1041            let stream = UnixStream::connect(&path).await?;
1042            Ok(stream)
1043        }
1044        UnixAddress::Abstract(name) => {
1045            tracing::debug!(name = %name, "Connecting to D-Bus (abstract)");
1046            let stream = connect_abstract(&name).await?;
1047            Ok(stream)
1048        }
1049    }
1050}
1051
1052/// Connect to an abstract Unix socket (Linux only).
1053#[cfg(target_os = "linux")]
1054async fn connect_abstract(name: &str) -> Result<UnixStream> {
1055    use std::os::linux::net::SocketAddrExt;
1056
1057    let addr = std::os::unix::net::SocketAddr::from_abstract_name(name)?;
1058    let std_stream = std::os::unix::net::UnixStream::connect_addr(&addr)?;
1059    std_stream.set_nonblocking(true)?;
1060    let stream = UnixStream::from_std(std_stream)?;
1061    Ok(stream)
1062}
1063
/// Fallback for platforms without abstract sockets: always returns an error.
#[cfg(not(target_os = "linux"))]
async fn connect_abstract(_name: &str) -> Result<UnixStream> {
    Err(anyhow::anyhow!("Abstract sockets are only supported on Linux"))
}
1068
1069/// Read a single line from a D-Bus auth handshake (terminated by CRLF).
1070async fn read_auth_line(stream: &mut UnixStream) -> Result<String> {
1071    use tokio::io::AsyncReadExt;
1072
1073    let mut buf = [0u8; 256];
1074    let mut response = Vec::new();
1075    loop {
1076        let n = stream.read(&mut buf).await?;
1077        if n == 0 {
1078            bail!("Connection closed during auth");
1079        }
1080        response.extend_from_slice(&buf[..n]);
1081        if response.windows(2).any(|w| w == b"\r\n") {
1082            break;
1083        }
1084    }
1085    Ok(String::from_utf8_lossy(&response).into_owned())
1086}
1087
1088/// Parse a D-Bus address string to extract the Unix socket address.
1089/// Supports formats:
1090/// - unix:path=/path/to/socket
1091/// - unix:abstract=name
1092fn parse_unix_address(addr: &str) -> Result<UnixAddress> {
1093    if !addr.starts_with("unix:") {
1094        bail!("Only unix: addresses are supported, got: {}", addr);
1095    }
1096
1097    let parts = &addr[5..]; // Skip "unix:"
1098
1099    for part in parts.split(',') {
1100        if let Some(path) = part.strip_prefix("path=") {
1101            return Ok(UnixAddress::Path(Path::new(path).to_path_buf()));
1102        }
1103        if let Some(name) = part.strip_prefix("abstract=") {
1104            return Ok(UnixAddress::Abstract(name.to_string()));
1105        }
1106    }
1107
1108    bail!("No path= or abstract= found in address: {}", addr);
1109}
1110
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `addr`, failing the test with the parser's error if it is rejected.
    fn parse_ok(addr: &str) -> UnixAddress {
        match parse_unix_address(addr) {
            Ok(parsed) => parsed,
            Err(err) => panic!("{:?} should parse, got error: {}", addr, err),
        }
    }

    #[test]
    fn test_parse_unix_address_path() {
        // (input address, expected socket path); the second case checks that
        // trailing key=value pairs are ignored.
        let cases = [
            ("unix:path=/run/user/1000/bus", "/run/user/1000/bus"),
            ("unix:path=/tmp/test.sock,guid=abc123", "/tmp/test.sock"),
        ];

        for (input, expected) in cases {
            match parse_ok(input) {
                UnixAddress::Path(path) => assert_eq!(path, Path::new(expected)),
                other => panic!("{:?} parsed as {:?}, expected a path", input, other),
            }
        }
    }

    #[test]
    fn test_parse_unix_address_abstract() {
        // (input address, expected abstract socket name)
        let cases = [
            ("unix:abstract=/tmp/dbus-test", "/tmp/dbus-test"),
            ("unix:abstract=dbus-session,guid=abc", "dbus-session"),
        ];

        for (input, expected) in cases {
            match parse_ok(input) {
                UnixAddress::Abstract(name) => assert_eq!(name, expected),
                other => panic!("{:?} parsed as {:?}, expected an abstract name", input, other),
            }
        }
    }

    #[test]
    fn test_parse_invalid_address() {
        // A non-unix transport, and a unix address with neither path= nor abstract=.
        for rejected in ["tcp:host=localhost", "unix:guid=abc123"] {
            assert!(
                parse_unix_address(rejected).is_err(),
                "{:?} should be rejected",
                rejected
            );
        }
    }
}