Skip to main content

dbus_router/
session.rs

1//! Client session handling with dual upstream connections and routing
2
3use crate::auth;
4use crate::config::Config;
5use crate::dbus_daemon::{
6    build_list_names_response, merge_list_names, needs_request_rewrite, needs_response_rewrite,
7    parse_string_array, rewrite_match_rule_body, rewrite_name_owner_changed,
8    rewrite_single_name_response, rewrite_string_array_response, rewrite_unique_name_request,
9    signal_needs_rewrite,
10};
11use crate::fake_name::get_bus_from_fake_name;
12use crate::message::{self, read_message, Message, MessageType};
13use crate::message_rewrite::{parse_match_rule, rewrite_message_header, RewriteDirection};
14use anyhow::{bail, Result};
15use std::collections::{HashMap, HashSet};
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use tokio::io::AsyncWriteExt;
19use tokio::net::UnixStream;
20
/// Identifies one of the two upstream buses a session is connected to.
///
/// Serves both as the routing target for client requests and as the origin
/// tag for traffic coming back from upstream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Bus {
    /// The host session bus.
    Host,
    /// The sandbox session bus (the default routing target).
    Sandbox,
}
27
28/// Log a message with all relevant fields for debugging.
29fn log_message(msg: &Message, direction: &str, target: Option<Bus>) {
30    let target_str = match target {
31        Some(Bus::Host) => " -> Host",
32        Some(Bus::Sandbox) => " -> Sandbox",
33        None => "",
34    };
35
36    tracing::trace!(
37        direction = direction,
38        target = target_str,
39        msg_type = ?msg.header.msg_type,
40        serial = msg.header.serial,
41        reply_serial = ?msg.header.reply_serial,
42        sender = ?msg.header.sender,
43        destination = ?msg.header.destination,
44        interface = ?msg.header.interface,
45        member = ?msg.header.member,
46        body_len = msg.header.body_len,
47        "{}",
48        format_message_summary(msg)
49    );
50}
51
52/// Format a one-line summary of the message for logging.
53fn format_message_summary(msg: &Message) -> String {
54    let type_str = match msg.header.msg_type {
55        MessageType::MethodCall => "CALL",
56        MessageType::MethodReturn => "REPLY",
57        MessageType::Error => "ERROR",
58        MessageType::Signal => "SIGNAL",
59        MessageType::Invalid => "INVALID",
60    };
61
62    let interface = msg.header.interface.as_deref().unwrap_or("-");
63    let member = msg.header.member.as_deref().unwrap_or("-");
64    let dest = msg.header.destination.as_deref().unwrap_or("-");
65    let sender = msg.header.sender.as_deref().unwrap_or("-");
66
67    format!(
68        "{} {}.{} [{}->{}] serial={}",
69        type_str, interface, member, sender, dest, msg.header.serial
70    )
71}
72
73/// Prepare a message from an upstream bus for forwarding to the client.
74/// Rewrites sender header and body as needed.
75fn prepare_message_for_client(
76    msg: &Message,
77    source_bus: Bus,
78    pending_calls: &HashMap<u32, PendingCallInfo>,
79) -> Vec<u8> {
80    let mut msg_for_client = msg.clone();
81
82    // Log message context for debugging header rewriting issues
83    tracing::trace!(
84        msg_type = ?msg.header.msg_type,
85        sender = ?msg.header.sender,
86        destination = ?msg.header.destination,
87        interface = ?msg.header.interface,
88        member = ?msg.header.member,
89        source_bus = ?source_bus,
90        "Preparing message for client"
91    );
92
93    // Rewrite sender header to add bus prefix
94    if let Err(e) =
95        rewrite_message_header(&mut msg_for_client, RewriteDirection::ToClient, source_bus)
96    {
97        tracing::warn!(
98            error = %e,
99            msg_type = ?msg.header.msg_type,
100            sender = ?msg.header.sender,
101            destination = ?msg.header.destination,
102            interface = ?msg.header.interface,
103            member = ?msg.header.member,
104            "Failed to rewrite sender header"
105        );
106    }
107
108    // Handle body rewriting based on message type
109    match msg_for_client.header.msg_type {
110        MessageType::MethodReturn => {
111            rewrite_method_return_body(&msg_for_client, source_bus, pending_calls)
112        }
113        MessageType::Signal => rewrite_signal_body(&msg_for_client, source_bus),
114        _ => msg_for_client.raw,
115    }
116}
117
118/// Rewrite method return body if it contains unique names that need prefixing.
119fn rewrite_method_return_body(
120    msg: &Message,
121    source_bus: Bus,
122    pending_calls: &HashMap<u32, PendingCallInfo>,
123) -> Vec<u8> {
124    let Some(reply_serial) = msg.header.reply_serial else {
125        return msg.raw.clone();
126    };
127
128    let Some(call_info) = pending_calls.get(&reply_serial) else {
129        return msg.raw.clone();
130    };
131
132    if call_info.bus != source_bus {
133        return msg.raw.clone();
134    }
135
136    let Some(ref member) = call_info.member else {
137        return msg.raw.clone();
138    };
139
140    if !needs_response_rewrite(member) {
141        return msg.raw.clone();
142    }
143
144    // ListQueuedOwners returns an array of unique names
145    let result = if member == "ListQueuedOwners" {
146        rewrite_string_array_response(msg, source_bus)
147    } else {
148        rewrite_single_name_response(msg, source_bus)
149    };
150
151    match result {
152        Ok(rewritten) => {
153            tracing::trace!(member = member, bus = ?source_bus, "Rewrote response body");
154            rewritten
155        }
156        Err(e) => {
157            tracing::warn!(member = member, error = %e, "Failed to rewrite response body");
158            msg.raw.clone()
159        }
160    }
161}
162
163/// Rewrite signal body if it contains unique names that need prefixing.
164fn rewrite_signal_body(msg: &Message, source_bus: Bus) -> Vec<u8> {
165    let Some(ref member) = msg.header.member else {
166        return msg.raw.clone();
167    };
168
169    if !signal_needs_rewrite(member) {
170        return msg.raw.clone();
171    }
172
173    match rewrite_name_owner_changed(msg, source_bus) {
174        Ok(rewritten) => {
175            tracing::trace!(member = member, bus = ?source_bus, "Rewrote signal body");
176            rewritten
177        }
178        Err(e) => {
179            tracing::warn!(member = member, error = %e, "Failed to rewrite signal body");
180            msg.raw.clone()
181        }
182    }
183}
184
/// Outcome of feeding one upstream reply into a pending ListNames-style merge.
enum MergeResult {
    /// The first of the two replies. It has been stashed, and nothing should be
    /// sent to the client yet.
    Stored,
    /// The second reply arrived. Holds the serialized merged message to send to
    /// the client.
    Complete(Vec<u8>),
}
192
193/// Process a ListNames/ListActivatableNames merge response.
194/// Returns `Some(MergeResult)` if the message was a merge response, `None` otherwise.
195fn process_merge_response(
196    msg: &Message,
197    source_bus: Bus,
198    pending_merges: &mut HashMap<u32, PendingMerge>,
199) -> Option<MergeResult> {
200    let reply_serial = msg.header.reply_serial?;
201    let pending = pending_merges.get_mut(&reply_serial)?;
202
203    let names = parse_string_array(msg).unwrap_or_default();
204    tracing::trace!(
205        serial = reply_serial,
206        count = names.len(),
207        bus = ?source_bus,
208        "Received ListNames response"
209    );
210
211    let Some((first_bus, first_names)) = pending.first_response.take() else {
212        // First response - store it
213        pending.first_response = Some((source_bus, names));
214        return Some(MergeResult::Stored);
215    };
216
217    // Second response - merge and build final message
218    let (host_names, sandbox_names) = if first_bus == Bus::Host {
219        (first_names, names)
220    } else {
221        (names, first_names)
222    };
223
224    let merged = merge_list_names(host_names, sandbox_names);
225    tracing::trace!(
226        serial = reply_serial,
227        count = merged.len(),
228        "Merged ListNames response"
229    );
230
231    let pending = pending_merges.remove(&reply_serial).unwrap();
232    let response = match build_list_names_response(&pending.original_request, merged) {
233        Ok(response) => response,
234        Err(e) => {
235            tracing::warn!(error = %e, "Failed to build merged response");
236            msg.raw.clone()
237        }
238    };
239
240    Some(MergeResult::Complete(response))
241}
242
243/// Pending merge state for ListNames/ListActivatableNames
244#[derive(Debug)]
245struct PendingMerge {
246    /// The original request message (for building response)
247    original_request: Message,
248    /// First response received (None if waiting for first)
249    first_response: Option<(Bus, Vec<String>)>,
250}
251
252/// Info about a pending call (for response rewriting)
253#[derive(Debug, Clone)]
254struct PendingCallInfo {
255    /// Which bus the call was sent to
256    bus: Bus,
257    /// The member (method name) of the call (for response rewriting)
258    member: Option<String>,
259}
260
261/// A client session with connections to both upstream buses.
262pub struct Session {
263    /// Connection from sandbox app
264    client: UnixStream,
265    /// Connection to host session bus
266    host_bus: UnixStream,
267    /// Connection to sandbox session bus (default target)
268    sandbox_bus: UnixStream,
269    /// Routing configuration
270    config: Arc<Config>,
271    /// Track which bus each outgoing serial was sent to (for routing replies)
272    pending_calls: HashMap<u32, PendingCallInfo>,
273    /// Client process executable path (for sandbox export permission check)
274    client_exe_path: Option<PathBuf>,
275    /// Services exported by this client to the host bus
276    exported_services: HashSet<String>,
277    /// Services registered by this client on the sandbox bus
278    sandbox_services: HashSet<String>,
279    /// Track incoming calls from upstream buses (serial -> source bus)
280    incoming_calls: HashMap<u32, Bus>,
281    /// Pending ListNames/ListActivatableNames merges (serial -> state)
282    pending_merges: HashMap<u32, PendingMerge>,
283}
284
285/// Get the executable path of a peer process from a Unix socket.
286#[cfg(target_os = "linux")]
287fn get_peer_exe_path(stream: &UnixStream) -> Option<PathBuf> {
288    use std::os::unix::io::AsRawFd;
289
290    let fd = stream.as_raw_fd();
291
292    // Get peer credentials using SO_PEERCRED
293    let mut ucred: libc::ucred = unsafe { std::mem::zeroed() };
294    let mut len = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
295
296    let ret = unsafe {
297        libc::getsockopt(
298            fd,
299            libc::SOL_SOCKET,
300            libc::SO_PEERCRED,
301            &mut ucred as *mut _ as *mut libc::c_void,
302            &mut len,
303        )
304    };
305
306    if ret != 0 {
307        tracing::debug!(
308            "Failed to get peer credentials: {}",
309            std::io::Error::last_os_error()
310        );
311        return None;
312    }
313
314    let pid = ucred.pid;
315    if pid <= 0 {
316        return None;
317    }
318
319    // Read /proc/{pid}/exe symlink
320    let exe_path = format!("/proc/{}/exe", pid);
321    match std::fs::read_link(&exe_path) {
322        Ok(path) => {
323            tracing::debug!(pid = pid, exe = %path.display(), "Got peer exe path");
324            Some(path)
325        }
326        Err(e) => {
327            tracing::debug!(pid = pid, error = %e, "Failed to read exe path");
328            None
329        }
330    }
331}
332
/// Fallback for non-Linux targets: the peer's executable is never resolved.
///
/// Always returns `None`, so the client is treated as having no known
/// executable.
#[cfg(not(target_os = "linux"))]
fn get_peer_exe_path(_stream: &UnixStream) -> Option<PathBuf> {
    // The SO_PEERCRED + /proc lookup is only implemented for Linux.
    None
}
337
338impl Session {
339    /// Create a new session with connections to both upstream buses.
340    pub async fn new(
341        client: UnixStream,
342        host_addr: &str,
343        sandbox_addr: &str,
344        config: Arc<Config>,
345    ) -> Result<Self> {
346        let client_exe_path = get_peer_exe_path(&client);
347        let host_bus = connect_dbus(host_addr).await?;
348        let sandbox_bus = connect_dbus(sandbox_addr).await?;
349
350        Ok(Self {
351            client,
352            host_bus,
353            sandbox_bus,
354            config,
355            pending_calls: HashMap::new(),
356            client_exe_path,
357            exported_services: HashSet::new(),
358            sandbox_services: HashSet::new(),
359            incoming_calls: HashMap::new(),
360            pending_merges: HashMap::new(),
361        })
362    }
363
364    /// Run the session: authenticate with both buses, then forward messages.
365    pub async fn run(mut self) -> Result<()> {
366        // Check if this is a hostpass client
367        let is_hostpass = self
368            .client_exe_path
369            .as_ref()
370            .map(|p| self.config.has_hostpass(p))
371            .unwrap_or(false);
372
373        // Phase 1: Auth passthrough with sandbox bus
374        tracing::debug!("Starting auth phase with sandbox bus");
375        auth::auth_passthrough(&mut self.client, &mut self.sandbox_bus).await?;
376        tracing::info!("Auth with sandbox bus completed");
377
378        // Phase 1b: Also authenticate with host bus (using same credentials)
379        // For hostpass clients, skip Hello() - their Hello() will be forwarded to host bus
380        // For non-hostpass clients, send Hello() so host_routes messages can be routed
381        tracing::debug!("Starting auth phase with host bus");
382        self.auth_host_bus(is_hostpass).await?;
383        tracing::info!("Auth with host bus completed, starting message forwarding");
384
385        // Phase 2: Message forwarding with routing
386        self.forward_loop().await
387    }
388
389    /// Authenticate with the host bus.
390    /// The host bus needs its own auth handshake.
391    /// If `skip_hello` is true, skip the Hello() call (for hostpass clients whose
392    /// Hello() will be forwarded to the host bus).
393    async fn auth_host_bus(&mut self, skip_hello: bool) -> Result<()> {
394        // Send null byte and EXTERNAL auth with hex-encoded UID
395        self.host_bus.write_all(&[0]).await?;
396        let uid = unsafe { libc::getuid() };
397        let uid_hex: String = uid
398            .to_string()
399            .bytes()
400            .map(|b| format!("{:02x}", b))
401            .collect();
402        self.host_bus
403            .write_all(format!("AUTH EXTERNAL {}\r\n", uid_hex).as_bytes())
404            .await?;
405
406        // Read auth response
407        let response = read_auth_line(&mut self.host_bus).await?;
408        if !response.starts_with("OK") {
409            bail!("Host bus auth failed: {}", response.trim());
410        }
411
412        // Negotiate UNIX FD passing
413        self.host_bus.write_all(b"NEGOTIATE_UNIX_FD\r\n").await?;
414        let response = read_auth_line(&mut self.host_bus).await?;
415        tracing::debug!(response = %response.trim(), "Host bus NEGOTIATE_UNIX_FD response");
416
417        // Send BEGIN to complete SASL auth
418        self.host_bus.write_all(b"BEGIN\r\n").await?;
419
420        if skip_hello {
421            tracing::debug!("Skipping Hello() for hostpass client");
422            Ok(())
423        } else {
424            self.send_host_hello().await
425        }
426    }
427
428    /// Send Hello() method call to host bus and read the response.
429    /// This registers the router's connection with the host bus daemon.
430    async fn send_host_hello(&mut self) -> Result<()> {
431        use zvariant::{serialized::Context, to_bytes, ObjectPath, Value, LE};
432
433        // Build header fields array for Hello() call
434        let path = ObjectPath::try_from("/org/freedesktop/DBus").unwrap();
435        let fields: Vec<(u8, Value)> = vec![
436            (1, Value::ObjectPath(path)),                   // PATH
437            (2, Value::Str("org.freedesktop.DBus".into())), // INTERFACE
438            (3, Value::Str("Hello".into())),                // MEMBER
439            (6, Value::Str("org.freedesktop.DBus".into())), // DESTINATION
440        ];
441
442        let ctxt = Context::new_dbus(LE, 12);
443        let fields_encoded = to_bytes(ctxt, &fields)?;
444        let array_len = fields_encoded.len() - 4; // Exclude 4-byte length prefix
445
446        // Calculate padding to 8-byte boundary
447        let header_end = 16 + array_len;
448        let padding = (8 - (header_end % 8)) % 8;
449
450        // Build D-Bus message: fixed header + fields + padding
451        let mut msg = Vec::with_capacity(16 + array_len + padding);
452        msg.extend_from_slice(&[b'l', 1, 0, 1]); // endian, method_call, flags, version
453        msg.extend_from_slice(&0u32.to_le_bytes()); // body length
454        msg.extend_from_slice(&1u32.to_le_bytes()); // serial
455        msg.extend_from_slice(&fields_encoded);
456        msg.resize(msg.len() + padding, 0);
457
458        self.host_bus.write_all(&msg).await?;
459
460        // Read and validate response.
461        // After Hello(), the bus daemon sends:
462        // 1. MethodReturn with our unique name
463        // 2. NameAcquired signal for that name
464        // We must consume both to prevent the signal from being forwarded to client,
465        // which would cause "Unexpected message Signal" errors in strict clients like zbus.
466        //
467        // First, read the MethodReturn
468        match message::read_message(&mut self.host_bus).await? {
469            Some(resp) if resp.header.msg_type == MessageType::MethodReturn => {
470                tracing::debug!("Host bus Hello() MethodReturn received");
471            }
472            Some(resp) if resp.header.msg_type == MessageType::Error => {
473                bail!("Host bus Hello() failed with error")
474            }
475            Some(resp) => bail!("Unexpected response to Hello(): {:?}", resp.header.msg_type),
476            None => bail!("Host bus disconnected after Hello()"),
477        }
478
479        // Then, read the NameAcquired signal
480        match message::read_message(&mut self.host_bus).await? {
481            Some(resp) if resp.header.msg_type == MessageType::Signal => {
482                tracing::debug!(
483                    interface = ?resp.header.interface,
484                    member = ?resp.header.member,
485                    "Consumed NameAcquired signal after Hello()"
486                );
487            }
488            Some(resp) => {
489                tracing::warn!(
490                    msg_type = ?resp.header.msg_type,
491                    "Unexpected second message after Hello(), expected Signal"
492                );
493            }
494            None => bail!("Host bus disconnected after Hello()"),
495        }
496
497        tracing::debug!("Host bus Hello() completed");
498        Ok(())
499    }
500
501    /// Forward messages between client and upstream buses with routing.
502    async fn forward_loop(mut self) -> Result<()> {
503        // Check if this is a hostpass client (they only use host bus)
504        let is_hostpass = self
505            .client_exe_path
506            .as_ref()
507            .map(|p| self.config.has_hostpass(p))
508            .unwrap_or(false);
509
510        let (client_read, mut client_write) = self.client.split();
511        let (host_read, mut host_write) = self.host_bus.split();
512        let (sandbox_read, mut sandbox_write) = self.sandbox_bus.split();
513
514        let mut client_read = tokio::io::BufReader::new(client_read);
515        let mut host_read = tokio::io::BufReader::new(host_read);
516        let mut sandbox_read = tokio::io::BufReader::new(sandbox_read);
517
518        // Track if sandbox bus is still active (for hostpass clients that survive sandbox disconnect)
519        let mut sandbox_active = true;
520
521        loop {
522            tokio::select! {
523                biased;
524                // Read from client and route to appropriate bus
525                result = read_message(&mut client_read) => {
526                    match result {
527                        Ok(Some(msg)) => {
528                            // Check if this is a reply to an incoming call from upstream bus
529                            if matches!(msg.header.msg_type, MessageType::MethodReturn | MessageType::Error) {
530                                if let Some(reply_serial) = msg.header.reply_serial {
531                                    if let Some(source_bus) = self.incoming_calls.remove(&reply_serial) {
532                                        // Rewrite destination from fake name to real name
533                                        let mut msg_for_bus = msg.clone();
534                                        if let Err(e) = rewrite_message_header(
535                                            &mut msg_for_bus,
536                                            RewriteDirection::ToUpstream,
537                                            source_bus,
538                                        ) {
539                                            tracing::warn!(error = %e, "Failed to rewrite destination for reply");
540                                        }
541
542                                        match source_bus {
543                                            Bus::Host => {
544                                                log_message(&msg, "Client->Host(reply)", Some(Bus::Host));
545                                                host_write.write_all(&msg_for_bus.raw).await?;
546                                            }
547                                            Bus::Sandbox => {
548                                                log_message(&msg, "Client->Sandbox(reply)", Some(Bus::Sandbox));
549                                                sandbox_write.write_all(&msg_for_bus.raw).await?;
550                                            }
551                                        }
552                                        continue;
553                                    }
554                                }
555                            }
556
557                            let decision = route_request(&self.config, &msg, self.client_exe_path.as_deref());
558
559                            // Track RequestName calls to know which services this client exports
560                            if msg.is_request_name() {
561                                if let Some(name) = msg.extract_name_from_body() {
562                                    match decision {
563                                        RouteDecision::Single(Bus::Host) => {
564                                            tracing::info!(service = %name, "Client exporting service to host bus");
565                                            self.exported_services.insert(name);
566                                        }
567                                        RouteDecision::Single(Bus::Sandbox) => {
568                                            tracing::info!(service = %name, "Client registering service on sandbox bus");
569                                            self.sandbox_services.insert(name);
570                                        }
571                                        _ => {}
572                                    }
573                                }
574                            }
575
576                            match decision {
577                                RouteDecision::Single(target) => {
578                                    // Prepare message for sending to upstream
579                                    let mut msg_for_upstream = msg.clone();
580
581                                    // 1. Rewrite destination header if it's a fake unique name
582                                    if let Err(e) = rewrite_message_header(
583                                        &mut msg_for_upstream,
584                                        RewriteDirection::ToUpstream,
585                                        target, // source_bus not used for ToUpstream
586                                    ) {
587                                        tracing::warn!(error = %e, "Failed to rewrite destination header");
588                                    }
589
590                                    // 2. Rewrite body for org.freedesktop.DBus methods
591                                    let msg_to_send = if msg_for_upstream.header.destination.as_deref() == Some("org.freedesktop.DBus") {
592                                        if let Some(member) = msg_for_upstream.header.member.as_deref() {
593                                            if needs_request_rewrite(member) {
594                                                // Rewrite unique name in body (GetConnectionCredentials etc.)
595                                                match rewrite_unique_name_request(&msg_for_upstream) {
596                                                    Ok((rewritten, _bus)) => {
597                                                        tracing::trace!(
598                                                            member = member,
599                                                            "Rewrote request body to remove fake prefix"
600                                                        );
601                                                        rewritten
602                                                    }
603                                                    Err(e) => {
604                                                        tracing::warn!(
605                                                            member = member,
606                                                            error = %e,
607                                                            "Failed to rewrite request body"
608                                                        );
609                                                        msg_for_upstream.raw.clone()
610                                                    }
611                                                }
612                                            } else if member == "AddMatch" || member == "RemoveMatch" {
613                                                // Rewrite sender in match rule
614                                                match rewrite_match_rule_body(&msg_for_upstream) {
615                                                    Ok(Some(rewritten)) => {
616                                                        tracing::trace!(
617                                                            member = member,
618                                                            "Rewrote match rule sender"
619                                                        );
620                                                        rewritten
621                                                    }
622                                                    Ok(None) => msg_for_upstream.raw.clone(),
623                                                    Err(e) => {
624                                                        tracing::warn!(
625                                                            member = member,
626                                                            error = %e,
627                                                            "Failed to rewrite match rule"
628                                                        );
629                                                        msg_for_upstream.raw.clone()
630                                                    }
631                                                }
632                                            } else {
633                                                msg_for_upstream.raw.clone()
634                                            }
635                                        } else {
636                                            msg_for_upstream.raw.clone()
637                                        }
638                                    } else {
639                                        msg_for_upstream.raw.clone()
640                                    };
641
642                                    self.pending_calls.insert(msg.header.serial, PendingCallInfo {
643                                        bus: target,
644                                        member: msg.header.member.clone(),
645                                    });
646                                    log_message(&msg, "Client", Some(target));
647
648                                    match target {
649                                        Bus::Host => host_write.write_all(&msg_to_send).await?,
650                                        Bus::Sandbox => sandbox_write.write_all(&msg_to_send).await?,
651                                    }
652                                }
653                                RouteDecision::Both => {
654                                    // Send to both buses (e.g., AddMatch without sender)
655                                    // Rewrite match rule body if needed
656                                    let msg_to_send = if msg.header.member.as_deref() == Some("AddMatch")
657                                        || msg.header.member.as_deref() == Some("RemoveMatch")
658                                    {
659                                        match rewrite_match_rule_body(&msg) {
660                                            Ok(Some(rewritten)) => rewritten,
661                                            Ok(None) => msg.raw.clone(),
662                                            Err(_) => msg.raw.clone(),
663                                        }
664                                    } else {
665                                        msg.raw.clone()
666                                    };
667
668                                    self.pending_calls.insert(msg.header.serial, PendingCallInfo {
669                                        bus: Bus::Sandbox,
670                                        member: msg.header.member.clone(),
671                                    });
672                                    log_message(&msg, "Client->Both", None);
673
674                                    host_write.write_all(&msg_to_send).await?;
675                                    sandbox_write.write_all(&msg_to_send).await?;
676                                }
677                                RouteDecision::Merge => {
678                                    // ListNames/ListActivatableNames: send to both and merge results
679                                    log_message(&msg, "Client->Merge", None);
680
681                                    self.pending_merges.insert(msg.header.serial, PendingMerge {
682                                        original_request: msg.clone(),
683                                        first_response: None,
684                                    });
685
686                                    host_write.write_all(&msg.raw).await?;
687                                    sandbox_write.write_all(&msg.raw).await?;
688                                }
689                            }
690                        }
691                        Ok(None) => {
692                            tracing::debug!("Client disconnected");
693                            return Ok(());
694                        }
695                        Err(e) => {
696                            tracing::debug!(error = %e, "Error reading from client");
697                            return Ok(());
698                        }
699                    }
700                }
701
702                // Read from host bus and forward to client
703                result = read_message(&mut host_read) => {
704                    match result {
705                        Ok(Some(msg)) => {
706                            // Check if this is a call to an exported service
707                            if msg.header.msg_type == MessageType::MethodCall {
708                                if let Some(ref dest) = msg.header.destination {
709                                    if self.exported_services.contains(dest) {
710                                        log_message(&msg, "Host->Client(exported)", None);
711                                        self.incoming_calls.insert(msg.header.serial, Bus::Host);
712                                        client_write.write_all(&msg.raw).await?;
713                                        continue;
714                                    }
715                                }
716                            }
717
718                            // Check if this is a response to a pending merge request
719                            if let Some(result) = process_merge_response(&msg, Bus::Host, &mut self.pending_merges) {
720                                if let MergeResult::Complete(response) = result {
721                                    client_write.write_all(&response).await?;
722                                }
723                                continue;
724                            }
725
726                            let msg_to_send = prepare_message_for_client(
727                                &msg,
728                                Bus::Host,
729                                &self.pending_calls,
730                            );
731                            log_message(&msg, "Host->Client", None);
732                            client_write.write_all(&msg_to_send).await?;
733                        }
734                        Ok(None) => {
735                            tracing::debug!("Host bus disconnected");
736                            return Ok(());
737                        }
738                        Err(e) => {
739                            tracing::debug!(error = %e, "Error reading from host bus");
740                            return Ok(());
741                        }
742                    }
743                }
744
745                // Read from sandbox bus and forward to client
746                result = read_message(&mut sandbox_read), if sandbox_active => {
747                    match result {
748                        Ok(Some(msg)) => {
749                            // Check if this is a response to a pending merge request
750                            if let Some(result) = process_merge_response(&msg, Bus::Sandbox, &mut self.pending_merges) {
751                                if let MergeResult::Complete(response) = result {
752                                    client_write.write_all(&response).await?;
753                                }
754                                continue;
755                            }
756
757                            // Check if this is a call to a sandbox service registered by this client
758                            if msg.header.msg_type == MessageType::MethodCall {
759                                if let Some(ref dest) = msg.header.destination {
760                                    if self.sandbox_services.contains(dest) {
761                                        // Rewrite sender to add :s. prefix before forwarding to client
762                                        let mut msg_for_client = msg.clone();
763                                        if let Err(e) = rewrite_message_header(
764                                            &mut msg_for_client,
765                                            RewriteDirection::ToClient,
766                                            Bus::Sandbox,
767                                        ) {
768                                            tracing::warn!(error = %e, "Failed to rewrite sender header for sandbox service call");
769                                        }
770
771                                        log_message(&msg, "Sandbox->Client(service)", None);
772                                        self.incoming_calls.insert(msg.header.serial, Bus::Sandbox);
773                                        client_write.write_all(&msg_for_client.raw).await?;
774                                        continue;
775                                    }
776                                }
777                            }
778
779                            let msg_to_send = prepare_message_for_client(
780                                &msg,
781                                Bus::Sandbox,
782                                &self.pending_calls,
783                            );
784                            log_message(&msg, "Sandbox->Client", None);
785                            client_write.write_all(&msg_to_send).await?;
786                        }
787                        Ok(None) => {
788                            tracing::debug!("Sandbox bus disconnected");
789                            if is_hostpass {
790                                // Hostpass clients only use host bus, so they can continue
791                                tracing::info!("Hostpass client continues after sandbox disconnect");
792                                sandbox_active = false;
793                                continue;
794                            }
795                            return Ok(());
796                        }
797                        Err(e) => {
798                            tracing::debug!(error = %e, "Error reading from sandbox bus");
799                            if is_hostpass {
800                                tracing::info!("Hostpass client continues after sandbox read error");
801                                sandbox_active = false;
802                                continue;
803                            }
804                            return Ok(());
805                        }
806                    }
807                }
808            }
809        }
810    }
811}
812
813/// Routing decision result
814#[derive(Debug, Clone)]
815pub enum RouteDecision {
816    /// Route to a single bus
817    Single(Bus),
818    /// Route to both buses (for AddMatch without sender)
819    Both,
820    /// Route to both buses and merge results (for ListNames)
821    Merge,
822}
823
824/// Determine which bus to route a request to based on destination.
825fn route_request(config: &Config, msg: &Message, client_exe: Option<&Path>) -> RouteDecision {
826    // Hostpass: route ALL messages from hostpass processes to host bus
827    // This is required because D-Bus requires Hello() before any other operations,
828    // and the host bus won't accept messages from connections that haven't called Hello()
829    if let Some(exe) = client_exe {
830        if config.has_hostpass(exe) {
831            tracing::trace!(exe = %exe.display(), "Routing to host (hostpass)");
832            return RouteDecision::Single(Bus::Host);
833        }
834    }
835
836    // Check if destination is a fake unique name (e.g., :h.1.45)
837    if let Some(ref dest) = msg.header.destination {
838        if let Some(bus) = get_bus_from_fake_name(dest) {
839            tracing::trace!(dest = %dest, bus = ?bus, "Routing by fake unique name");
840            return RouteDecision::Single(bus);
841        }
842    }
843
844    // Special handling for org.freedesktop.DBus methods
845    if msg.header.destination.as_deref() == Some("org.freedesktop.DBus")
846        && msg.header.interface.as_deref() == Some("org.freedesktop.DBus")
847    {
848        return route_dbus_daemon_call(config, msg);
849    }
850
851    // Method calls and signals with a destination are routed based on config
852    if let Some(ref dest) = msg.header.destination {
853        if config.should_route_to_host(dest) {
854            return RouteDecision::Single(Bus::Host);
855        }
856    }
857
858    // Default: route to sandbox bus
859    RouteDecision::Single(Bus::Sandbox)
860}
861
862/// Route calls to org.freedesktop.DBus based on method and arguments
863fn route_dbus_daemon_call(config: &Config, msg: &Message) -> RouteDecision {
864    let member = msg.header.member.as_deref().unwrap_or("");
865
866    match member {
867        // AddMatch/RemoveMatch: route based on sender in the match rule
868        "AddMatch" | "RemoveMatch" => {
869            if let Some(rule) = msg.extract_string_from_body() {
870                let info = parse_match_rule(&rule);
871
872                // If sender is specified, route based on sender
873                if let Some(ref sender) = info.sender {
874                    // Check if sender is a fake unique name
875                    if let Some(bus) = get_bus_from_fake_name(sender) {
876                        tracing::trace!(
877                            member = member,
878                            sender = %sender,
879                            bus = ?bus,
880                            "Routing {} by fake sender name",
881                            member
882                        );
883                        return RouteDecision::Single(bus);
884                    }
885
886                    // Check if sender matches host_routes
887                    if config.should_route_to_host(sender) {
888                        tracing::trace!(
889                            member = member,
890                            sender = %sender,
891                            "Routing {} to host (sender in host_routes)",
892                            member
893                        );
894                        return RouteDecision::Single(Bus::Host);
895                    }
896                }
897
898                // If interface is specified but no sender, try to route by interface
899                if let Some(ref interface) = info.interface {
900                    // Extract service name from interface (e.g., org.fcitx.Fcitx5.Controller1 -> org.fcitx.Fcitx5)
901                    if let Some(service) = extract_service_from_interface(interface) {
902                        if config.should_route_to_host(&service) {
903                            tracing::trace!(
904                                member = member,
905                                interface = %interface,
906                                service = %service,
907                                "Routing {} to host (interface matches host_routes)",
908                                member
909                            );
910                            return RouteDecision::Single(Bus::Host);
911                        }
912                    }
913                }
914
915                // No sender or interface to determine routing - send to both buses
916                tracing::trace!(
917                    member = member,
918                    rule = %rule,
919                    "Routing {} to both buses (no sender specified)",
920                    member
921                );
922                return RouteDecision::Both;
923            }
924
925            // Can't parse rule, default to sandbox
926            RouteDecision::Single(Bus::Sandbox)
927        }
928
929        // RequestName/ReleaseName: route based on the service name being registered
930        "RequestName" | "ReleaseName" => {
931            if let Some(name) = msg.extract_name_from_body() {
932                if config.should_route_to_host(&name) {
933                    tracing::trace!(
934                        member = member,
935                        name = %name,
936                        "Routing {} to host (name in host_routes)",
937                        member
938                    );
939                    return RouteDecision::Single(Bus::Host);
940                }
941            }
942            RouteDecision::Single(Bus::Sandbox)
943        }
944
945        // GetNameOwner/NameHasOwner/StartServiceByName: route based on queried name
946        "GetNameOwner" | "NameHasOwner" | "StartServiceByName" => {
947            if let Some(name) = msg.extract_name_from_body() {
948                // Check if name is a fake unique name
949                if let Some(bus) = get_bus_from_fake_name(&name) {
950                    return RouteDecision::Single(bus);
951                }
952
953                if config.should_route_to_host(&name) {
954                    tracing::trace!(
955                        member = member,
956                        name = %name,
957                        "Routing {} to host (name in host_routes)",
958                        member
959                    );
960                    return RouteDecision::Single(Bus::Host);
961                }
962            }
963            RouteDecision::Single(Bus::Sandbox)
964        }
965
966        // GetConnectionCredentials etc: route based on unique name argument
967        "GetConnectionCredentials" | "GetConnectionUnixUser" | "GetConnectionUnixProcessID" => {
968            if let Some(name) = msg.extract_name_from_body() {
969                if let Some(bus) = get_bus_from_fake_name(&name) {
970                    tracing::trace!(
971                        member = member,
972                        name = %name,
973                        bus = ?bus,
974                        "Routing {} by fake unique name",
975                        member
976                    );
977                    return RouteDecision::Single(bus);
978                }
979            }
980            // Unknown unique name, default to sandbox
981            RouteDecision::Single(Bus::Sandbox)
982        }
983
984        // ListNames/ListActivatableNames: merge results from both buses
985        "ListNames" | "ListActivatableNames" => {
986            tracing::trace!(
987                member = member,
988                "Routing {} to both buses for merge",
989                member
990            );
991            RouteDecision::Merge
992        }
993
994        // Other methods: default to sandbox
995        _ => RouteDecision::Single(Bus::Sandbox),
996    }
997}
998
/// Guess a bus name from a D-Bus interface name by keeping its first three
/// dot-separated components.
///
/// e.g., "org.fcitx.Fcitx5.Controller1" -> "org.fcitx.Fcitx5"
///
/// Returns `None` when the interface has fewer than three components. Only
/// this single prefix length is tried, so this is a heuristic. A service whose
/// interfaces are not prefixed by a three-component service name will not be
/// recognised.
fn extract_service_from_interface(interface: &str) -> Option<String> {
    let mut dot_positions = interface.match_indices('.').map(|(pos, _)| pos);
    // Three components need at least two separators; otherwise give up.
    dot_positions.nth(1)?;
    // A third separator, if present, ends the third component.
    let end = dot_positions.next().unwrap_or(interface.len());
    Some(interface[..end].to_string())
}
1012
/// A Unix-domain socket endpoint extracted from a D-Bus address string.
#[derive(Debug, Clone)]
enum UnixAddress {
    /// Socket bound to a filesystem path (`unix:path=...`).
    Path(PathBuf),
    /// Socket in the abstract namespace (`unix:abstract=...`); Linux only.
    Abstract(String),
}
1021
1022/// Connect to a D-Bus address.
1023/// Supports "unix:path=/path/to/socket" and "unix:abstract=name" formats.
1024async fn connect_dbus(addr: &str) -> Result<UnixStream> {
1025    let unix_addr = parse_unix_address(addr)?;
1026    match unix_addr {
1027        UnixAddress::Path(path) => {
1028            tracing::debug!(path = %path.display(), "Connecting to D-Bus (path)");
1029            let stream = UnixStream::connect(&path).await?;
1030            Ok(stream)
1031        }
1032        UnixAddress::Abstract(name) => {
1033            tracing::debug!(name = %name, "Connecting to D-Bus (abstract)");
1034            let stream = connect_abstract(&name).await?;
1035            Ok(stream)
1036        }
1037    }
1038}
1039
1040/// Connect to an abstract Unix socket (Linux only).
1041#[cfg(target_os = "linux")]
1042async fn connect_abstract(name: &str) -> Result<UnixStream> {
1043    use std::os::linux::net::SocketAddrExt;
1044
1045    let addr = std::os::unix::net::SocketAddr::from_abstract_name(name)?;
1046    let std_stream = std::os::unix::net::UnixStream::connect_addr(&addr)?;
1047    std_stream.set_nonblocking(true)?;
1048    let stream = UnixStream::from_std(std_stream)?;
1049    Ok(stream)
1050}
1051
1052#[cfg(not(target_os = "linux"))]
1053async fn connect_abstract(_name: &str) -> Result<UnixStream> {
1054    bail!("Abstract sockets are only supported on Linux");
1055}
1056
1057/// Read a single line from a D-Bus auth handshake (terminated by CRLF).
1058async fn read_auth_line(stream: &mut UnixStream) -> Result<String> {
1059    use tokio::io::AsyncReadExt;
1060
1061    let mut buf = [0u8; 256];
1062    let mut response = Vec::new();
1063    loop {
1064        let n = stream.read(&mut buf).await?;
1065        if n == 0 {
1066            bail!("Connection closed during auth");
1067        }
1068        response.extend_from_slice(&buf[..n]);
1069        if response.windows(2).any(|w| w == b"\r\n") {
1070            break;
1071        }
1072    }
1073    Ok(String::from_utf8_lossy(&response).into_owned())
1074}
1075
1076/// Parse a D-Bus address string to extract the Unix socket address.
1077/// Supports formats:
1078/// - unix:path=/path/to/socket
1079/// - unix:abstract=name
1080fn parse_unix_address(addr: &str) -> Result<UnixAddress> {
1081    if !addr.starts_with("unix:") {
1082        bail!("Only unix: addresses are supported, got: {}", addr);
1083    }
1084
1085    let parts = &addr[5..]; // Skip "unix:"
1086
1087    for part in parts.split(',') {
1088        if let Some(path) = part.strip_prefix("path=") {
1089            return Ok(UnixAddress::Path(Path::new(path).to_path_buf()));
1090        }
1091        if let Some(name) = part.strip_prefix("abstract=") {
1092            return Ok(UnixAddress::Abstract(name.to_string()));
1093        }
1094    }
1095
1096    bail!("No path= or abstract= found in address: {}", addr);
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `addr`, expecting a filesystem-path socket, and return the path.
    fn expect_path(addr: &str) -> PathBuf {
        match parse_unix_address(addr).unwrap() {
            UnixAddress::Path(path) => path,
            other => panic!("expected a path address for {addr:?}, got {other:?}"),
        }
    }

    /// Parse `addr`, expecting an abstract socket, and return its name.
    fn expect_abstract(addr: &str) -> String {
        match parse_unix_address(addr).unwrap() {
            UnixAddress::Abstract(name) => name,
            other => panic!("expected an abstract address for {addr:?}, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_unix_address_path() {
        assert_eq!(
            expect_path("unix:path=/run/user/1000/bus"),
            PathBuf::from("/run/user/1000/bus")
        );
        // Extra keys such as guid= after the path are ignored.
        assert_eq!(
            expect_path("unix:path=/tmp/test.sock,guid=abc123"),
            PathBuf::from("/tmp/test.sock")
        );
    }

    #[test]
    fn test_parse_unix_address_abstract() {
        assert_eq!(expect_abstract("unix:abstract=/tmp/dbus-test"), "/tmp/dbus-test");
        assert_eq!(expect_abstract("unix:abstract=dbus-session,guid=abc"), "dbus-session");
    }

    #[test]
    fn test_parse_invalid_address() {
        // Non-unix transport.
        assert!(parse_unix_address("tcp:host=localhost").is_err());
        // Unix transport, but neither path= nor abstract=.
        assert!(parse_unix_address("unix:guid=abc123").is_err());
    }
}