Skip to main content

rtmp_rs/client/
connector.rs

//! RTMP client connector
//!
//! Low-level client for connecting to RTMP servers: TCP connect, handshake,
//! the `connect` / `createStream` / `play` commands, and raw message I/O.

5use std::collections::HashMap;
6
7use bytes::{Buf, Bytes, BytesMut};
8use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::net::TcpStream;
10use tokio::time::timeout;
11
12use crate::amf::AmfValue;
13use crate::error::{Error, Result};
14use crate::protocol::chunk::{ChunkDecoder, ChunkEncoder, RtmpChunk};
15use crate::protocol::constants::*;
16use crate::protocol::enhanced::{EnhancedCapabilities, EnhancedRtmpMode};
17use crate::protocol::handshake::{Handshake, HandshakeRole};
18use crate::protocol::message::{Command, ConnectParams, RtmpMessage};
19
20use super::config::{ClientConfig, ParsedUrl};
21
22/// RTMP client connector
23pub struct RtmpConnector {
24    config: ClientConfig,
25    parsed_url: ParsedUrl,
26    reader: BufReader<tokio::io::ReadHalf<TcpStream>>,
27    writer: BufWriter<tokio::io::WriteHalf<TcpStream>>,
28    read_buf: BytesMut,
29    write_buf: BytesMut,
30    chunk_decoder: ChunkDecoder,
31    chunk_encoder: ChunkEncoder,
32    stream_id: u32,
33    /// Negotiated E-RTMP capabilities (if E-RTMP is active)
34    enhanced_capabilities: Option<EnhancedCapabilities>,
35}
36
37impl RtmpConnector {
38    /// Connect to an RTMP server
39    pub async fn connect(config: ClientConfig) -> Result<Self> {
40        let parsed_url = config
41            .parse_url()
42            .ok_or_else(|| Error::Config("Invalid RTMP URL".into()))?;
43
44        let addr = format!("{}:{}", parsed_url.host, parsed_url.port);
45
46        let socket = timeout(config.connect_timeout, TcpStream::connect(&addr))
47            .await
48            .map_err(|_| Error::Timeout)?
49            .map_err(Error::Io)?;
50
51        if config.tcp_nodelay {
52            socket.set_nodelay(true)?;
53        }
54
55        let (read_half, write_half) = tokio::io::split(socket);
56
57        let mut connector = Self {
58            config,
59            parsed_url,
60            reader: BufReader::with_capacity(64 * 1024, read_half),
61            writer: BufWriter::with_capacity(64 * 1024, write_half),
62            read_buf: BytesMut::with_capacity(64 * 1024),
63            write_buf: BytesMut::with_capacity(64 * 1024),
64            chunk_decoder: ChunkDecoder::new(),
65            chunk_encoder: ChunkEncoder::new(),
66            stream_id: 0,
67            enhanced_capabilities: None,
68        };
69
70        connector.do_handshake().await?;
71        connector.do_connect().await?;
72
73        Ok(connector)
74    }
75
76    /// Perform handshake
77    async fn do_handshake(&mut self) -> Result<()> {
78        let mut handshake = Handshake::new(HandshakeRole::Client);
79
80        // Send C0C1
81        let c0c1 = handshake.generate_initial().ok_or(Error::Protocol(
82            crate::error::ProtocolError::InvalidChunkHeader,
83        ))?;
84        self.writer.write_all(&c0c1).await?;
85        self.writer.flush().await?;
86
87        // Wait for S0S1S2 and send C2
88        let timeout_duration = self.config.connect_timeout;
89        timeout(timeout_duration, async {
90            loop {
91                let n = self.reader.read_buf(&mut self.read_buf).await?;
92                if n == 0 {
93                    return Err(Error::ConnectionClosed);
94                }
95
96                let mut buf = Bytes::copy_from_slice(&self.read_buf);
97                if let Some(response) = handshake.process(&mut buf)? {
98                    let consumed = self.read_buf.len() - buf.len();
99                    self.read_buf.advance(consumed);
100
101                    self.writer.write_all(&response).await?;
102                    self.writer.flush().await?;
103                }
104
105                if handshake.is_done() {
106                    break;
107                }
108            }
109            Ok::<_, Error>(())
110        })
111        .await
112        .map_err(|_| Error::Timeout)??;
113
114        Ok(())
115    }
116
117    /// Send connect command
118    async fn do_connect(&mut self) -> Result<()> {
119        let mut obj = HashMap::new();
120        obj.insert(
121            "app".to_string(),
122            AmfValue::String(self.parsed_url.app.clone()),
123        );
124        obj.insert("type".to_string(), AmfValue::String("nonprivate".into()));
125        obj.insert(
126            "flashVer".to_string(),
127            AmfValue::String(self.config.flash_ver.clone()),
128        );
129        obj.insert(
130            "tcUrl".to_string(),
131            AmfValue::String(self.config.url.clone()),
132        );
133        obj.insert("fpad".to_string(), AmfValue::Boolean(false));
134        obj.insert("capabilities".to_string(), AmfValue::Number(15.0));
135        obj.insert("audioCodecs".to_string(), AmfValue::Number(3191.0));
136        obj.insert("videoCodecs".to_string(), AmfValue::Number(252.0));
137        obj.insert("videoFunction".to_string(), AmfValue::Number(1.0));
138
139        // Add E-RTMP fields if not in LegacyOnly mode
140        let client_caps = if !matches!(self.config.enhanced_rtmp, EnhancedRtmpMode::LegacyOnly) {
141            let caps = self.config.enhanced_capabilities.to_enhanced_capabilities();
142            self.add_ertmp_fields(&mut obj, &caps);
143            Some(caps)
144        } else {
145            None
146        };
147
148        let cmd = Command {
149            name: CMD_CONNECT.to_string(),
150            transaction_id: 1.0,
151            command_object: AmfValue::Object(obj),
152            arguments: vec![],
153            stream_id: 0,
154        };
155
156        self.send_command(&cmd).await?;
157
158        // Wait for connect result
159        loop {
160            let msg = self.read_message().await?;
161            match msg {
162                RtmpMessage::Command(cmd) if cmd.name == CMD_RESULT => {
163                    // Parse server's E-RTMP response
164                    self.handle_connect_result(&cmd, client_caps.as_ref())?;
165                    break;
166                }
167                RtmpMessage::Command(cmd) if cmd.name == CMD_ERROR => {
168                    return Err(Error::Rejected("Connect rejected".into()));
169                }
170                RtmpMessage::SetChunkSize(size) => {
171                    self.chunk_decoder.set_chunk_size(size);
172                }
173                RtmpMessage::WindowAckSize(_) | RtmpMessage::SetPeerBandwidth { .. } => {
174                    // Ignore these during connect
175                }
176                _ => {}
177            }
178        }
179
180        // Set our chunk size
181        self.chunk_encoder.set_chunk_size(RECOMMENDED_CHUNK_SIZE);
182        self.send_message(&RtmpMessage::SetChunkSize(RECOMMENDED_CHUNK_SIZE))
183            .await?;
184
185        Ok(())
186    }
187
188    /// Add E-RTMP fields to the connect command object.
189    fn add_ertmp_fields(&self, obj: &mut HashMap<String, AmfValue>, caps: &EnhancedCapabilities) {
190        // Add capsEx
191        obj.insert(
192            "capsEx".to_string(),
193            AmfValue::Number(caps.caps_ex.bits() as f64),
194        );
195
196        // Add fourCcList for compatibility (older E-RTMP implementations)
197        let mut fourcc_list: Vec<AmfValue> = Vec::new();
198
199        // Add video codecs to fourCcList
200        for codec in caps.video_codecs.keys() {
201            fourcc_list.push(AmfValue::String(codec.as_fourcc_str().to_string()));
202        }
203
204        // Add audio codecs to fourCcList
205        for codec in caps.audio_codecs.keys() {
206            fourcc_list.push(AmfValue::String(codec.as_fourcc_str().to_string()));
207        }
208
209        if !fourcc_list.is_empty() {
210            obj.insert("fourCcList".to_string(), AmfValue::Array(fourcc_list));
211        }
212
213        // Add videoFourCcInfoMap (modern E-RTMP)
214        if !caps.video_codecs.is_empty() {
215            let mut video_map: HashMap<String, AmfValue> = HashMap::new();
216            for (codec, capability) in &caps.video_codecs {
217                video_map.insert(
218                    codec.as_fourcc_str().to_string(),
219                    AmfValue::Number(capability.bits() as f64),
220                );
221            }
222            obj.insert(
223                "videoFourCcInfoMap".to_string(),
224                AmfValue::Object(video_map),
225            );
226        }
227
228        // Add audioFourCcInfoMap (modern E-RTMP)
229        if !caps.audio_codecs.is_empty() {
230            let mut audio_map: HashMap<String, AmfValue> = HashMap::new();
231            for (codec, capability) in &caps.audio_codecs {
232                audio_map.insert(
233                    codec.as_fourcc_str().to_string(),
234                    AmfValue::Number(capability.bits() as f64),
235                );
236            }
237            obj.insert(
238                "audioFourCcInfoMap".to_string(),
239                AmfValue::Object(audio_map),
240            );
241        }
242    }
243
244    /// Handle connect result and parse E-RTMP capabilities from server response.
245    fn handle_connect_result(
246        &mut self,
247        cmd: &Command,
248        client_caps: Option<&EnhancedCapabilities>,
249    ) -> Result<()> {
250        // Parse server's response for E-RTMP fields
251        let server_params = ConnectParams::from_amf(&cmd.command_object);
252
253        if server_params.has_enhanced_rtmp() {
254            // Server responded with E-RTMP capabilities
255            let server_caps = server_params.to_enhanced_capabilities();
256
257            // Intersect with our client caps to get final negotiated caps
258            let negotiated = if let Some(client) = client_caps {
259                client.intersect(&server_caps)
260            } else {
261                server_caps
262            };
263
264            if negotiated.enabled {
265                tracing::debug!(
266                    video_codecs = negotiated.video_codecs.len(),
267                    audio_codecs = negotiated.audio_codecs.len(),
268                    caps_ex = negotiated.caps_ex.bits(),
269                    "E-RTMP negotiated with server"
270                );
271                self.enhanced_capabilities = Some(negotiated);
272            }
273        } else if matches!(self.config.enhanced_rtmp, EnhancedRtmpMode::EnhancedOnly) {
274            // We required E-RTMP but server doesn't support it
275            return Err(Error::Rejected(
276                "Server does not support Enhanced RTMP".into(),
277            ));
278        }
279
280        Ok(())
281    }
282
283    /// Create a stream for publishing or playing
284    pub async fn create_stream(&mut self) -> Result<u32> {
285        let cmd = Command {
286            name: CMD_CREATE_STREAM.to_string(),
287            transaction_id: 2.0,
288            command_object: AmfValue::Null,
289            arguments: vec![],
290            stream_id: 0,
291        };
292
293        self.send_command(&cmd).await?;
294
295        // Wait for result
296        loop {
297            let msg = self.read_message().await?;
298            if let RtmpMessage::Command(result) = msg {
299                if result.name == CMD_RESULT && result.transaction_id == 2.0 {
300                    if let Some(id) = result.arguments.first().and_then(|v| v.as_number()) {
301                        self.stream_id = id as u32;
302                        return Ok(self.stream_id);
303                    }
304                }
305            }
306        }
307    }
308
309    /// Start playing a stream
310    pub async fn play(&mut self, stream_name: &str) -> Result<()> {
311        if self.stream_id == 0 {
312            self.create_stream().await?;
313        }
314
315        // Set buffer length
316        self.send_message(&RtmpMessage::UserControl(
317            crate::protocol::message::UserControlEvent::SetBufferLength {
318                stream_id: self.stream_id,
319                buffer_ms: self.config.buffer_length,
320            },
321        ))
322        .await?;
323
324        // Send play command
325        let cmd = Command {
326            name: CMD_PLAY.to_string(),
327            transaction_id: 0.0,
328            command_object: AmfValue::Null,
329            arguments: vec![
330                AmfValue::String(stream_name.to_string()),
331                AmfValue::Number(-2.0),  // Start: live or recorded
332                AmfValue::Number(-1.0),  // Duration: play until end
333                AmfValue::Boolean(true), // Reset
334            ],
335            stream_id: self.stream_id,
336        };
337
338        self.send_command(&cmd).await?;
339
340        // Wait for onStatus
341        loop {
342            let msg = self.read_message().await?;
343            if let RtmpMessage::Command(status) = msg {
344                if status.name == CMD_ON_STATUS {
345                    if let Some(info) = status.arguments.first().and_then(|v| v.as_object()) {
346                        if let Some(code) = info.get("code").and_then(|v| v.as_str()) {
347                            if code == NS_PLAY_START {
348                                return Ok(());
349                            } else if code.contains("Failed") || code.contains("Error") {
350                                return Err(Error::Rejected(code.to_string()));
351                            }
352                        }
353                    }
354                }
355            }
356        }
357    }
358
359    /// Read the next RTMP message
360    pub async fn read_message(&mut self) -> Result<RtmpMessage> {
361        loop {
362            // Try to decode from buffer
363            while let Some(chunk) = self.chunk_decoder.decode(&mut self.read_buf)? {
364                return RtmpMessage::from_chunk(&chunk);
365            }
366
367            // Need more data
368            let n = self.reader.read_buf(&mut self.read_buf).await?;
369            if n == 0 {
370                return Err(Error::ConnectionClosed);
371            }
372        }
373    }
374
375    /// Send an RTMP message
376    async fn send_message(&mut self, msg: &RtmpMessage) -> Result<()> {
377        let (msg_type, payload) = msg.encode();
378
379        let csid = match msg {
380            RtmpMessage::SetChunkSize(_)
381            | RtmpMessage::WindowAckSize(_)
382            | RtmpMessage::SetPeerBandwidth { .. }
383            | RtmpMessage::UserControl(_) => CSID_PROTOCOL_CONTROL,
384            RtmpMessage::Command(_) | RtmpMessage::CommandAmf3(_) => CSID_COMMAND,
385            _ => CSID_COMMAND,
386        };
387
388        let chunk = RtmpChunk {
389            csid,
390            timestamp: 0,
391            message_type: msg_type,
392            stream_id: 0,
393            payload,
394        };
395
396        self.write_buf.clear();
397        self.chunk_encoder.encode(&chunk, &mut self.write_buf);
398        self.writer.write_all(&self.write_buf).await?;
399        self.writer.flush().await?;
400
401        Ok(())
402    }
403
404    /// Send a command
405    async fn send_command(&mut self, cmd: &Command) -> Result<()> {
406        self.send_message(&RtmpMessage::Command(cmd.clone())).await
407    }
408
409    /// Get the stream ID
410    pub fn stream_id(&self) -> u32 {
411        self.stream_id
412    }
413
414    /// Check if E-RTMP is active for this connection
415    pub fn is_enhanced_rtmp(&self) -> bool {
416        self.enhanced_capabilities
417            .as_ref()
418            .map(|c| c.enabled)
419            .unwrap_or(false)
420    }
421
422    /// Get the negotiated E-RTMP capabilities
423    pub fn enhanced_capabilities(&self) -> Option<&EnhancedCapabilities> {
424        self.enhanced_capabilities.as_ref()
425    }
426}