1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
//! Crate for discovering Holochain peers over mDNS.
//! Works by broadcasting a service of type `_<service_type>._udp` and
//! carrying the payload as base64url-encoded data (prefixed with `u`)
//! in TXT records.
//!
//! Uses the libmdns crate for broadcasting.
//! Uses the mdns crate for discovery.
use err_derive::Error;
use mdns::RecordKind;
use std::time::Duration;
use tokio_stream::{Stream, StreamExt};

use std::sync::atomic::{AtomicBool, Ordering};

/// Protocol suffix appended to the caller-supplied service type (`_<type>._udp`).
const HC_SERVICE_PROTOCOL: &str = "._udp";
/// Seconds the broadcast task sleeps between checks of its run flag.
const BROADCAST_INTERVAL_SEC: u64 = 8;
/// Interval, in seconds, handed to `mdns::discover::all` when querying.
const QUERY_INTERVAL_SEC: u64 = 5;
/// Maximum length, in bytes, of each TXT string carrying the base64 payload.
const MAX_TXT_SIZE: usize = 192;

/// Errors yielded while discovering peers (see `mdns_listen()`).
#[derive(Debug, Error)]
pub enum MdnsError {
    /// Error reported by the underlying `mdns` crate.
    #[error(display = "Regular Mdns error {}", _0)]
    Mdns(#[error(source)] mdns::Error),
    /// TXT record payload that could not be decoded as base64.
    #[error(display = "Base64 decoding error {}", _0)]
    Base64(#[error(source)] base64::DecodeError),
}

/// Ask the broadcast task started by `mdns_create_broadcast_thread()` to stop.
///
/// This only clears the shared run flag (relaxed ordering); the task notices
/// the change on its next periodic check and then exits.
pub fn mdns_kill_thread(can_run: ::std::sync::Arc<AtomicBool>) {
    AtomicBool::store(&can_run, false, Ordering::Relaxed);
}

/// Spawn a tokio task that advertises a service over mDNS.
///
/// The service is registered with type `_<service_type>._udp` and instance
/// name `service_name`. `buffer` is encoded as unpadded URL-safe base64,
/// prefixed with `u`, and split into TXT strings of at most `MAX_TXT_SIZE`
/// bytes each.
///
/// Returns the shared run flag of the task (not a channel sender). Pass it to
/// `mdns_kill_thread()` to stop advertising; the task only checks the flag
/// every `BROADCAST_INTERVAL_SEC` seconds, so shutdown is not immediate.
///
/// # Panics
///
/// - If `_<service_type>._udp` or `service_name` is 63 bytes or longer
///   (libmdns limits).
/// - If called outside a tokio runtime (`tokio::task::spawn` requires one).
/// - The spawned task (not the caller) panics if the libmdns responder cannot
///   be created.
pub fn mdns_create_broadcast_thread(
    service_type: String,
    service_name: String,
    buffer: &[u8],
) -> ::std::sync::Arc<AtomicBool> {
    let svc_type = format!("_{}{}", service_type, HC_SERVICE_PROTOCOL);
    // Length constraints imposed by libmdns.
    assert!(
        svc_type.len() < 63,
        "len = {} ({}) ; {}",
        svc_type.len(),
        service_type.len(),
        service_type
    );
    assert!(service_name.len() < 63);

    let can_run = ::std::sync::Arc::new(AtomicBool::new(true));
    let can_run_task = ::std::sync::Arc::clone(&can_run);

    // `u` prefix + unpadded base64url, cut into TXT-sized pieces up front so the
    // task only owns the final chunks.
    let b64 = format!(
        "u{}",
        base64::encode_config(buffer, base64::URL_SAFE_NO_PAD)
    );
    let txt_chunks = split_txt_chunks(&b64, MAX_TXT_SIZE);

    tokio::task::spawn(async move {
        let txts: Vec<&str> = txt_chunks.iter().map(String::as_str).collect();
        let responder =
            libmdns::Responder::new().expect("failed to create libmdns responder");
        // Hold the registration handle for the lifetime of the task.
        let _svc = responder.register(svc_type, service_name, 0, &txts);
        // Keep advertising until `mdns_kill_thread()` clears the flag.
        loop {
            tokio::time::sleep(::std::time::Duration::from_secs(BROADCAST_INTERVAL_SEC)).await;
            if !can_run_task.load(Ordering::Relaxed) {
                break;
            }
        }
    });

    can_run
}

/// Split `s` into consecutive pieces of at most `max_len` bytes, in O(len).
///
/// Always returns at least one piece (an empty input yields one empty string).
/// Cuts are made on byte offsets, so `s` must be ASCII (base64 output is);
/// `max_len` must be non-zero.
fn split_txt_chunks(s: &str, max_len: usize) -> Vec<String> {
    debug_assert!(max_len > 0, "TXT chunk size must be non-zero");
    let mut chunks = Vec::with_capacity(s.len() / max_len.max(1) + 1);
    let mut rest = s;
    while rest.len() > max_len {
        let (head, tail) = rest.split_at(max_len);
        chunks.push(head.to_owned());
        rest = tail;
    }
    chunks.push(rest.to_owned());
    chunks
}

/// A single peer advertisement decoded from an mDNS response by `mdns_listen()`.
///
/// Implements `PartialEq`/`Eq`/`Hash` so callers can compare or deduplicate
/// responses (the same peer typically answers repeatedly).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MdnsResponse {
    /// Service type, read from the PTR answer's record name (empty if none)
    pub service_type: String,
    /// Service instance name, read from the PTR answer's target (empty if none)
    pub service_name: String,
    /// IP address that responded to the mdns query
    pub addr: std::net::IpAddr,
    /// Payload decoded from the base64 data in the TXT record (empty if none)
    pub buffer: Vec<u8>,
}

/// Queries the network for services of type `_<service_type>._udp.local`.
///
/// Returns a stream (not an iterator). It yields one item per non-empty mDNS
/// response that has a source IP address; other responses are skipped.
/// A successful item carries:
/// - `addr`: the address reported by the response,
/// - `service_name`: first label of the PTR answer's target,
/// - `service_type`: first label of the PTR answer's record name, minus its
///   leading character (normally `_`),
/// - `buffer`: all TXT strings concatenated, first character (the `u` prefix)
///   dropped, decoded as unpadded URL-safe base64.
///
/// Fields whose record is absent are left empty.
///
/// # Errors
///
/// The stream yields `MdnsError::Mdns` for errors reported by the mdns
/// listener, and `MdnsError::Base64` when a TXT payload is not valid base64.
///
/// # Panics
///
/// Panics if the mdns discovery query cannot be created.
pub fn mdns_listen(service_type: String) -> impl Stream<Item = Result<MdnsResponse, MdnsError>> {
    let svc_type = format!("_{}{}.local", service_type, HC_SERVICE_PROTOCOL);
    let query = mdns::discover::all(svc_type, Duration::from_secs(QUERY_INTERVAL_SEC))
        .expect("mdns Discover failed");
    query
        .listen()
        .filter_map(|maybe_response| -> Option<Result<MdnsResponse, MdnsError>> {
            let response = match maybe_response {
                Ok(response) => response,
                // Forward listener errors to the caller instead of dropping them.
                Err(e) => return Some(Err(MdnsError::Mdns(e))),
            };
            // Skip empty responses and responses without a source address.
            if response.is_empty() {
                return None;
            }
            // NOTE: if response.ip_addr() is not the right address, we should
            // return all A/AAAA records found in the answers instead.
            let addr = response.ip_addr()?;

            let mut buffer = Vec::new();
            let mut service_name = String::new();
            let mut service_type = String::new();
            for answer in response.answers {
                match answer.kind {
                    RecordKind::TXT(txts) => {
                        // The payload may be split over several TXT strings.
                        let b64 = txts.concat();
                        // Drop the one-character prefix without panicking on an
                        // empty or non-ASCII payload.
                        buffer = match base64::decode_config(
                            without_first_char(&b64),
                            base64::URL_SAFE_NO_PAD,
                        ) {
                            Ok(decoded) => decoded,
                            Err(e) => return Some(Err(MdnsError::Base64(e))),
                        };
                    }
                    // Retrieve service name and type from the PTR record.
                    RecordKind::PTR(ptr) => {
                        service_name = ptr.split('.').next().unwrap_or("").to_string();
                        let type_label = answer.name.split("._").next().unwrap_or("");
                        service_type = without_first_char(type_label).to_string();
                    }
                    _ => {}
                }
            }
            Some(Ok(MdnsResponse {
                service_type,
                service_name,
                addr,
                buffer,
            }))
        })
}

/// Return `s` without its first character, or `""` when `s` is empty.
///
/// Unlike `&s[1..]`, this never panics on empty strings or on a multi-byte
/// leading character.
fn without_first_char(s: &str) -> &str {
    let mut chars = s.chars();
    chars.next();
    chars.as_str()
}