1use std::collections::HashMap;
6
7use bytes::{Buf, Bytes, BytesMut};
8use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::net::TcpStream;
10use tokio::time::timeout;
11
12use crate::amf::AmfValue;
13use crate::error::{Error, Result};
14use crate::protocol::chunk::{ChunkDecoder, ChunkEncoder, RtmpChunk};
15use crate::protocol::constants::*;
16use crate::protocol::enhanced::{EnhancedCapabilities, EnhancedRtmpMode};
17use crate::protocol::handshake::{Handshake, HandshakeRole};
18use crate::protocol::message::{Command, ConnectParams, RtmpMessage};
19
20use super::config::{ClientConfig, ParsedUrl};
21
22pub struct RtmpConnector {
24 config: ClientConfig,
25 parsed_url: ParsedUrl,
26 reader: BufReader<tokio::io::ReadHalf<TcpStream>>,
27 writer: BufWriter<tokio::io::WriteHalf<TcpStream>>,
28 read_buf: BytesMut,
29 write_buf: BytesMut,
30 chunk_decoder: ChunkDecoder,
31 chunk_encoder: ChunkEncoder,
32 stream_id: u32,
33 enhanced_capabilities: Option<EnhancedCapabilities>,
35}
36
37impl RtmpConnector {
38 pub async fn connect(config: ClientConfig) -> Result<Self> {
40 let parsed_url = config
41 .parse_url()
42 .ok_or_else(|| Error::Config("Invalid RTMP URL".into()))?;
43
44 let addr = format!("{}:{}", parsed_url.host, parsed_url.port);
45
46 let socket = timeout(config.connect_timeout, TcpStream::connect(&addr))
47 .await
48 .map_err(|_| Error::Timeout)?
49 .map_err(Error::Io)?;
50
51 if config.tcp_nodelay {
52 socket.set_nodelay(true)?;
53 }
54
55 let (read_half, write_half) = tokio::io::split(socket);
56
57 let mut connector = Self {
58 config,
59 parsed_url,
60 reader: BufReader::with_capacity(64 * 1024, read_half),
61 writer: BufWriter::with_capacity(64 * 1024, write_half),
62 read_buf: BytesMut::with_capacity(64 * 1024),
63 write_buf: BytesMut::with_capacity(64 * 1024),
64 chunk_decoder: ChunkDecoder::new(),
65 chunk_encoder: ChunkEncoder::new(),
66 stream_id: 0,
67 enhanced_capabilities: None,
68 };
69
70 connector.do_handshake().await?;
71 connector.do_connect().await?;
72
73 Ok(connector)
74 }
75
76 async fn do_handshake(&mut self) -> Result<()> {
78 let mut handshake = Handshake::new(HandshakeRole::Client);
79
80 let c0c1 = handshake.generate_initial().ok_or(Error::Protocol(
82 crate::error::ProtocolError::InvalidChunkHeader,
83 ))?;
84 self.writer.write_all(&c0c1).await?;
85 self.writer.flush().await?;
86
87 let timeout_duration = self.config.connect_timeout;
89 timeout(timeout_duration, async {
90 loop {
91 let n = self.reader.read_buf(&mut self.read_buf).await?;
92 if n == 0 {
93 return Err(Error::ConnectionClosed);
94 }
95
96 let mut buf = Bytes::copy_from_slice(&self.read_buf);
97 if let Some(response) = handshake.process(&mut buf)? {
98 let consumed = self.read_buf.len() - buf.len();
99 self.read_buf.advance(consumed);
100
101 self.writer.write_all(&response).await?;
102 self.writer.flush().await?;
103 }
104
105 if handshake.is_done() {
106 break;
107 }
108 }
109 Ok::<_, Error>(())
110 })
111 .await
112 .map_err(|_| Error::Timeout)??;
113
114 Ok(())
115 }
116
117 async fn do_connect(&mut self) -> Result<()> {
119 let mut obj = HashMap::new();
120 obj.insert(
121 "app".to_string(),
122 AmfValue::String(self.parsed_url.app.clone()),
123 );
124 obj.insert("type".to_string(), AmfValue::String("nonprivate".into()));
125 obj.insert(
126 "flashVer".to_string(),
127 AmfValue::String(self.config.flash_ver.clone()),
128 );
129 obj.insert(
130 "tcUrl".to_string(),
131 AmfValue::String(self.config.url.clone()),
132 );
133 obj.insert("fpad".to_string(), AmfValue::Boolean(false));
134 obj.insert("capabilities".to_string(), AmfValue::Number(15.0));
135 obj.insert("audioCodecs".to_string(), AmfValue::Number(3191.0));
136 obj.insert("videoCodecs".to_string(), AmfValue::Number(252.0));
137 obj.insert("videoFunction".to_string(), AmfValue::Number(1.0));
138
139 let client_caps = if !matches!(self.config.enhanced_rtmp, EnhancedRtmpMode::LegacyOnly) {
141 let caps = self.config.enhanced_capabilities.to_enhanced_capabilities();
142 self.add_ertmp_fields(&mut obj, &caps);
143 Some(caps)
144 } else {
145 None
146 };
147
148 let cmd = Command {
149 name: CMD_CONNECT.to_string(),
150 transaction_id: 1.0,
151 command_object: AmfValue::Object(obj),
152 arguments: vec![],
153 stream_id: 0,
154 };
155
156 self.send_command(&cmd).await?;
157
158 loop {
160 let msg = self.read_message().await?;
161 match msg {
162 RtmpMessage::Command(cmd) if cmd.name == CMD_RESULT => {
163 self.handle_connect_result(&cmd, client_caps.as_ref())?;
165 break;
166 }
167 RtmpMessage::Command(cmd) if cmd.name == CMD_ERROR => {
168 return Err(Error::Rejected("Connect rejected".into()));
169 }
170 RtmpMessage::SetChunkSize(size) => {
171 self.chunk_decoder.set_chunk_size(size);
172 }
173 RtmpMessage::WindowAckSize(_) | RtmpMessage::SetPeerBandwidth { .. } => {
174 }
176 _ => {}
177 }
178 }
179
180 self.chunk_encoder.set_chunk_size(RECOMMENDED_CHUNK_SIZE);
182 self.send_message(&RtmpMessage::SetChunkSize(RECOMMENDED_CHUNK_SIZE))
183 .await?;
184
185 Ok(())
186 }
187
188 fn add_ertmp_fields(&self, obj: &mut HashMap<String, AmfValue>, caps: &EnhancedCapabilities) {
190 obj.insert(
192 "capsEx".to_string(),
193 AmfValue::Number(caps.caps_ex.bits() as f64),
194 );
195
196 let mut fourcc_list: Vec<AmfValue> = Vec::new();
198
199 for codec in caps.video_codecs.keys() {
201 fourcc_list.push(AmfValue::String(codec.as_fourcc_str().to_string()));
202 }
203
204 for codec in caps.audio_codecs.keys() {
206 fourcc_list.push(AmfValue::String(codec.as_fourcc_str().to_string()));
207 }
208
209 if !fourcc_list.is_empty() {
210 obj.insert("fourCcList".to_string(), AmfValue::Array(fourcc_list));
211 }
212
213 if !caps.video_codecs.is_empty() {
215 let mut video_map: HashMap<String, AmfValue> = HashMap::new();
216 for (codec, capability) in &caps.video_codecs {
217 video_map.insert(
218 codec.as_fourcc_str().to_string(),
219 AmfValue::Number(capability.bits() as f64),
220 );
221 }
222 obj.insert(
223 "videoFourCcInfoMap".to_string(),
224 AmfValue::Object(video_map),
225 );
226 }
227
228 if !caps.audio_codecs.is_empty() {
230 let mut audio_map: HashMap<String, AmfValue> = HashMap::new();
231 for (codec, capability) in &caps.audio_codecs {
232 audio_map.insert(
233 codec.as_fourcc_str().to_string(),
234 AmfValue::Number(capability.bits() as f64),
235 );
236 }
237 obj.insert(
238 "audioFourCcInfoMap".to_string(),
239 AmfValue::Object(audio_map),
240 );
241 }
242 }
243
244 fn handle_connect_result(
246 &mut self,
247 cmd: &Command,
248 client_caps: Option<&EnhancedCapabilities>,
249 ) -> Result<()> {
250 let server_params = ConnectParams::from_amf(&cmd.command_object);
252
253 if server_params.has_enhanced_rtmp() {
254 let server_caps = server_params.to_enhanced_capabilities();
256
257 let negotiated = if let Some(client) = client_caps {
259 client.intersect(&server_caps)
260 } else {
261 server_caps
262 };
263
264 if negotiated.enabled {
265 tracing::debug!(
266 video_codecs = negotiated.video_codecs.len(),
267 audio_codecs = negotiated.audio_codecs.len(),
268 caps_ex = negotiated.caps_ex.bits(),
269 "E-RTMP negotiated with server"
270 );
271 self.enhanced_capabilities = Some(negotiated);
272 }
273 } else if matches!(self.config.enhanced_rtmp, EnhancedRtmpMode::EnhancedOnly) {
274 return Err(Error::Rejected(
276 "Server does not support Enhanced RTMP".into(),
277 ));
278 }
279
280 Ok(())
281 }
282
283 pub async fn create_stream(&mut self) -> Result<u32> {
285 let cmd = Command {
286 name: CMD_CREATE_STREAM.to_string(),
287 transaction_id: 2.0,
288 command_object: AmfValue::Null,
289 arguments: vec![],
290 stream_id: 0,
291 };
292
293 self.send_command(&cmd).await?;
294
295 loop {
297 let msg = self.read_message().await?;
298 if let RtmpMessage::Command(result) = msg {
299 if result.name == CMD_RESULT && result.transaction_id == 2.0 {
300 if let Some(id) = result.arguments.first().and_then(|v| v.as_number()) {
301 self.stream_id = id as u32;
302 return Ok(self.stream_id);
303 }
304 }
305 }
306 }
307 }
308
309 pub async fn play(&mut self, stream_name: &str) -> Result<()> {
311 if self.stream_id == 0 {
312 self.create_stream().await?;
313 }
314
315 self.send_message(&RtmpMessage::UserControl(
317 crate::protocol::message::UserControlEvent::SetBufferLength {
318 stream_id: self.stream_id,
319 buffer_ms: self.config.buffer_length,
320 },
321 ))
322 .await?;
323
324 let cmd = Command {
326 name: CMD_PLAY.to_string(),
327 transaction_id: 0.0,
328 command_object: AmfValue::Null,
329 arguments: vec![
330 AmfValue::String(stream_name.to_string()),
331 AmfValue::Number(-2.0), AmfValue::Number(-1.0), AmfValue::Boolean(true), ],
335 stream_id: self.stream_id,
336 };
337
338 self.send_command(&cmd).await?;
339
340 loop {
342 let msg = self.read_message().await?;
343 if let RtmpMessage::Command(status) = msg {
344 if status.name == CMD_ON_STATUS {
345 if let Some(info) = status.arguments.first().and_then(|v| v.as_object()) {
346 if let Some(code) = info.get("code").and_then(|v| v.as_str()) {
347 if code == NS_PLAY_START {
348 return Ok(());
349 } else if code.contains("Failed") || code.contains("Error") {
350 return Err(Error::Rejected(code.to_string()));
351 }
352 }
353 }
354 }
355 }
356 }
357 }
358
359 pub async fn read_message(&mut self) -> Result<RtmpMessage> {
361 loop {
362 while let Some(chunk) = self.chunk_decoder.decode(&mut self.read_buf)? {
364 return RtmpMessage::from_chunk(&chunk);
365 }
366
367 let n = self.reader.read_buf(&mut self.read_buf).await?;
369 if n == 0 {
370 return Err(Error::ConnectionClosed);
371 }
372 }
373 }
374
375 async fn send_message(&mut self, msg: &RtmpMessage) -> Result<()> {
377 let (msg_type, payload) = msg.encode();
378
379 let csid = match msg {
380 RtmpMessage::SetChunkSize(_)
381 | RtmpMessage::WindowAckSize(_)
382 | RtmpMessage::SetPeerBandwidth { .. }
383 | RtmpMessage::UserControl(_) => CSID_PROTOCOL_CONTROL,
384 RtmpMessage::Command(_) | RtmpMessage::CommandAmf3(_) => CSID_COMMAND,
385 _ => CSID_COMMAND,
386 };
387
388 let chunk = RtmpChunk {
389 csid,
390 timestamp: 0,
391 message_type: msg_type,
392 stream_id: 0,
393 payload,
394 };
395
396 self.write_buf.clear();
397 self.chunk_encoder.encode(&chunk, &mut self.write_buf);
398 self.writer.write_all(&self.write_buf).await?;
399 self.writer.flush().await?;
400
401 Ok(())
402 }
403
404 async fn send_command(&mut self, cmd: &Command) -> Result<()> {
406 self.send_message(&RtmpMessage::Command(cmd.clone())).await
407 }
408
409 pub fn stream_id(&self) -> u32 {
411 self.stream_id
412 }
413
414 pub fn is_enhanced_rtmp(&self) -> bool {
416 self.enhanced_capabilities
417 .as_ref()
418 .map(|c| c.enabled)
419 .unwrap_or(false)
420 }
421
422 pub fn enhanced_capabilities(&self) -> Option<&EnhancedCapabilities> {
424 self.enhanced_capabilities.as_ref()
425 }
426}