1use crate::codec::{CodecPipeline, CodecType, CodecError};
2use crate::jitter::JitterBuffer;
3use crate::packet::RtpPacket;
4use std::collections::VecDeque;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use thiserror::Error;
8use tokio::net::UdpSocket;
9use tokio::sync::mpsc;
10
11#[derive(Debug, Error)]
12pub enum SessionError {
13 #[error("IO error: {0}")]
14 Io(#[from] std::io::Error),
15 #[error("codec error: {0}")]
16 Codec(#[from] CodecError),
17 #[error("RTP error: {0}")]
18 Rtp(#[from] crate::packet::RtpError),
19 #[error("session not started")]
20 NotStarted,
21 #[error("invalid DTMF digit: {0}")]
22 InvalidDtmfDigit(char),
23}
24
/// One decoded RFC 2833 telephone-event packet, as produced by
/// `parse_rfc2833_event`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DtmfEvent {
    /// Keypad character the event code maps to (`0-9`, `*`, `#`, `A-D`).
    pub digit: char,
    /// State of the top bit (0x80) of the second payload byte.
    pub end: bool,
    /// Big-endian 16-bit value read from payload bytes 2 and 3.
    pub duration: u16,
    /// Low six bits (0x3F) of the second payload byte.
    pub volume: u8,
    /// Sequence number of the RTP packet that carried the event.
    pub sequence_number: u16,
    /// Timestamp of the RTP packet that carried the event.
    pub timestamp: u32,
}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum ReceiveEvent {
37 Audio(Vec<i16>),
38 Dtmf(DtmfEvent),
39}
40
/// A digit waiting in `RtpSession::dtmf_queue` to be transmitted.
#[derive(Clone, Debug)]
struct QueuedDtmf {
    /// Digit to send; stored upper-cased by `queue_rfc2833_digits`.
    digit: char,
    /// Total event duration passed to the sender, in RTP timestamp units.
    duration_samples: u16,
}
46
47#[derive(Debug, Clone)]
49pub struct SessionConfig {
50 pub local_addr: String,
51 pub remote_addr: SocketAddr,
52 pub codec: CodecType,
53 pub ssrc: u32,
54 pub jitter_buffer_size: usize,
55}
56
57impl SessionConfig {
58 pub fn new(local_addr: &str, remote_addr: SocketAddr, codec: CodecType) -> Self {
59 Self {
60 local_addr: local_addr.to_string(),
61 remote_addr,
62 codec,
63 ssrc: rand::random(),
64 jitter_buffer_size: 10,
65 }
66 }
67}
68
69pub struct RtpSession {
71 socket: Arc<UdpSocket>,
72 config: SessionConfig,
73 codec: CodecPipeline,
74 sequence_number: u16,
75 timestamp: u32,
76 local_addr: SocketAddr,
77 dtmf_queue: VecDeque<QueuedDtmf>,
78}
79
80impl RtpSession {
81 pub async fn new(config: SessionConfig) -> Result<Self, SessionError> {
83 let socket = UdpSocket::bind(&config.local_addr).await?;
84 let local_addr = socket.local_addr()?;
85 let codec = CodecPipeline::new(config.codec);
86
87 Ok(Self {
88 socket: Arc::new(socket),
89 config,
90 codec,
91 sequence_number: 0,
92 timestamp: 0,
93 local_addr,
94 dtmf_queue: VecDeque::new(),
95 })
96 }
97
98 pub fn local_addr(&self) -> SocketAddr {
100 self.local_addr
101 }
102
103 pub fn set_remote_addr(&mut self, addr: SocketAddr) {
105 self.config.remote_addr = addr;
106 }
107
108 pub async fn send_audio(&mut self, pcm_samples: &[i16]) -> Result<usize, SessionError> {
110 let encoded = self.codec.encode(pcm_samples)?;
111
112 let packet = RtpPacket::new(
113 self.config.codec.payload_type(),
114 self.sequence_number,
115 self.timestamp,
116 self.config.ssrc,
117 )
118 .with_payload(encoded);
119
120 let data = packet.serialize();
121 let sent = self.socket.send_to(&data, self.config.remote_addr).await?;
122
123 self.sequence_number = self.sequence_number.wrapping_add(1);
124 self.timestamp = self
125 .timestamp
126 .wrapping_add(self.config.codec.samples_per_frame() as u32);
127
128 Ok(sent)
129 }
130
131 pub async fn send_rfc2833_digit(
135 &mut self,
136 digit: char,
137 payload_type: u8,
138 ) -> Result<(), SessionError> {
139 self.send_rfc2833_digit_with_duration(digit, payload_type, 800).await
140 }
141
142 async fn send_rfc2833_digit_with_duration(
143 &mut self,
144 digit: char,
145 payload_type: u8,
146 duration_samples: u16,
147 ) -> Result<(), SessionError> {
148 let event = dtmf_digit_to_event(digit).ok_or(SessionError::InvalidDtmfDigit(digit))?;
149 let start_ts = self.timestamp;
150 let ramps = [160u16, 320u16, duration_samples];
151 for (idx, dur) in ramps.iter().enumerate() {
152 let end = idx == ramps.len() - 1;
153 let marker = idx == 0;
154 let payload = vec![
155 event,
156 ((end as u8) << 7) | 10u8, (dur >> 8) as u8,
158 (*dur & 0xFF) as u8,
159 ];
160 let packet = RtpPacket::new(payload_type, self.sequence_number, start_ts, self.config.ssrc)
161 .with_marker(marker)
162 .with_payload(payload);
163 let data = packet.serialize();
164 self.socket.send_to(&data, self.config.remote_addr).await?;
165 self.sequence_number = self.sequence_number.wrapping_add(1);
166 }
167 self.timestamp = self.timestamp.wrapping_add(duration_samples as u32);
169 Ok(())
170 }
171
172 pub fn queue_rfc2833_digits(&mut self, digits: &str) -> Result<usize, SessionError> {
176 let mut queued = 0usize;
177 for ch in digits.chars().filter(|c| !c.is_whitespace()) {
178 validate_dtmf_digit(ch)?;
179 self.dtmf_queue.push_back(QueuedDtmf {
180 digit: ch.to_ascii_uppercase(),
181 duration_samples: 800,
182 });
183 queued += 1;
184 }
185 Ok(queued)
186 }
187
188 pub fn queued_rfc2833_digits(&self) -> usize {
189 self.dtmf_queue.len()
190 }
191
192 pub async fn send_next_queued_rfc2833(
194 &mut self,
195 payload_type: u8,
196 ) -> Result<Option<char>, SessionError> {
197 let Some(next) = self.dtmf_queue.pop_front() else {
198 return Ok(None);
199 };
200 self.send_rfc2833_digit_with_duration(next.digit, payload_type, next.duration_samples)
201 .await?;
202 Ok(Some(next.digit))
203 }
204
205 pub async fn flush_queued_rfc2833(
207 &mut self,
208 payload_type: u8,
209 inter_digit_gap_ms: u64,
210 ) -> Result<usize, SessionError> {
211 let mut sent = 0usize;
212 while let Some(_digit) = self.send_next_queued_rfc2833(payload_type).await? {
213 sent += 1;
214 if !self.dtmf_queue.is_empty() && inter_digit_gap_ms > 0 {
215 tokio::time::sleep(std::time::Duration::from_millis(inter_digit_gap_ms)).await;
216 }
217 }
218 Ok(sent)
219 }
220
221 pub async fn send_packet(&self, packet: &RtpPacket) -> Result<usize, SessionError> {
223 let data = packet.serialize();
224 let sent = self.socket.send_to(&data, self.config.remote_addr).await?;
225 Ok(sent)
226 }
227
228 pub async fn recv_packet(&self) -> Result<(RtpPacket, SocketAddr), SessionError> {
230 let mut buf = vec![0u8; 65535];
231 let (len, source) = self.socket.recv_from(&mut buf).await?;
232 let packet = RtpPacket::parse(&buf[..len])?;
233 Ok((packet, source))
234 }
235
236 pub fn decode_packet(&mut self, packet: &RtpPacket) -> Result<Vec<i16>, SessionError> {
238 Ok(self.codec.decode(&packet.payload)?)
239 }
240
241 pub fn silence_frame(&self) -> Vec<u8> {
243 self.codec.silence_frame()
244 }
245
246 pub fn start_receiving(
248 &self,
249 buffer_size: usize,
250 ) -> (
251 mpsc::Receiver<Vec<i16>>,
252 mpsc::Sender<()>,
253 ) {
254 let (audio_tx, audio_rx) = mpsc::channel(buffer_size);
255 let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
256 let socket = self.socket.clone();
257 let codec_type = self.config.codec;
258 let jitter_size = self.config.jitter_buffer_size;
259
260 tokio::spawn(async move {
261 let mut codec = CodecPipeline::new(codec_type);
262 let mut jitter = JitterBuffer::new(jitter_size);
263 let mut buf = vec![0u8; 65535];
264
265 loop {
266 tokio::select! {
267 result = socket.recv_from(&mut buf) => {
268 match result {
269 Ok((len, _source)) => {
270 match RtpPacket::parse(&buf[..len]) {
271 Ok(packet) => {
272 if packet.payload_type >= 64 && packet.payload_type <= 95 {
277 continue;
278 }
279 if packet.payload_type != codec_type.payload_type() {
282 continue;
283 }
284
285 jitter.insert(packet);
286
287 if let Some(pkt) = jitter.pop() {
290 match codec.decode(&pkt.payload) {
291 Ok(samples) => {
292 match audio_tx.try_send(samples) {
294 Ok(_) => {}
295 Err(mpsc::error::TrySendError::Full(_)) => {
296 }
298 Err(mpsc::error::TrySendError::Closed(_)) => {
299 return; }
301 }
302 }
303 Err(_e) => {
304 tracing::warn!("RTP decode error: {}", _e);
305 }
306 }
307 }
308 }
309 Err(_) => {
310 }
312 }
313 }
314 Err(e) => {
315 tracing::error!("RTP receive error: {}", e);
316 break;
317 }
318 }
319 }
320 _ = stop_rx.recv() => {
321 break;
322 }
323 }
324 }
325 });
326
327 (audio_rx, stop_tx)
328 }
329
330 pub fn start_receiving_events(
332 &self,
333 buffer_size: usize,
334 dtmf_payload_type: Option<u8>,
335 ) -> (
336 mpsc::Receiver<ReceiveEvent>,
337 mpsc::Sender<()>,
338 ) {
339 let (event_tx, event_rx) = mpsc::channel(buffer_size);
340 let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
341 let socket = self.socket.clone();
342 let codec_type = self.config.codec;
343 let jitter_size = self.config.jitter_buffer_size;
344
345 tokio::spawn(async move {
346 let mut codec = CodecPipeline::new(codec_type);
347 let mut jitter = JitterBuffer::new(jitter_size);
348 let mut buf = vec![0u8; 65535];
349
350 loop {
351 tokio::select! {
352 result = socket.recv_from(&mut buf) => {
353 match result {
354 Ok((len, _source)) => {
355 match RtpPacket::parse(&buf[..len]) {
356 Ok(packet) => {
357 if packet.payload_type >= 64 && packet.payload_type <= 95 {
359 continue;
360 }
361
362 if let Some(pt) = dtmf_payload_type {
363 if packet.payload_type == pt {
364 if let Some(dtmf) = parse_rfc2833_event(&packet) {
365 match event_tx.try_send(ReceiveEvent::Dtmf(dtmf)) {
366 Ok(_) => {}
367 Err(mpsc::error::TrySendError::Full(_)) => {}
368 Err(mpsc::error::TrySendError::Closed(_)) => return,
369 }
370 }
371 continue;
372 }
373 }
374
375 jitter.insert(packet);
376 if let Some(pkt) = jitter.pop() {
377 match codec.decode(&pkt.payload) {
378 Ok(samples) => {
379 match event_tx.try_send(ReceiveEvent::Audio(samples)) {
380 Ok(_) => {}
381 Err(mpsc::error::TrySendError::Full(_)) => {}
382 Err(mpsc::error::TrySendError::Closed(_)) => return,
383 }
384 }
385 Err(_e) => {
386 tracing::warn!("RTP decode error: {}", _e);
387 }
388 }
389 }
390 }
391 Err(_) => {}
392 }
393 }
394 Err(e) => {
395 tracing::error!("RTP receive error: {}", e);
396 break;
397 }
398 }
399 }
400 _ = stop_rx.recv() => {
401 break;
402 }
403 }
404 }
405 });
406
407 (event_rx, stop_tx)
408 }
409
410 pub fn stats(&self) -> SessionStats {
412 SessionStats {
413 local_addr: self.local_addr,
414 remote_addr: self.config.remote_addr,
415 codec: self.config.codec,
416 ssrc: self.config.ssrc,
417 packets_sent: self.sequence_number as u64,
418 }
419 }
420
421 pub fn codec(&self) -> &CodecPipeline {
423 &self.codec
424 }
425}
426
427fn validate_dtmf_digit(digit: char) -> Result<(), SessionError> {
428 if dtmf_digit_to_event(digit).is_some() {
429 Ok(())
430 } else {
431 Err(SessionError::InvalidDtmfDigit(digit))
432 }
433}
434
/// Maps a keypad character to its telephone-event code.
///
/// `0`-`9` map to 0-9, `*` to 10, `#` to 11 and `A`-`D` to 12-15. ASCII
/// letters are matched case-insensitively. Any other character yields
/// `None`.
fn dtmf_digit_to_event(digit: char) -> Option<u8> {
    // Position in this table is the event code.
    const KEYPAD: &str = "0123456789*#ABCD";
    KEYPAD
        .find(digit.to_ascii_uppercase())
        .map(|position| position as u8)
}
456
/// Maps a telephone-event code back to its keypad character.
///
/// Codes 0-9 give `0`-`9`, 10 gives `*`, 11 gives `#` and 12-15 give
/// `A`-`D`. Codes above 15 yield `None`.
fn dtmf_event_to_digit(event: u8) -> Option<char> {
    // Index in this table is the event code.
    const KEYPAD: &[u8; 16] = b"0123456789*#ABCD";
    KEYPAD
        .get(usize::from(event))
        .map(|&symbol| char::from(symbol))
}
469
470fn parse_rfc2833_event(packet: &RtpPacket) -> Option<DtmfEvent> {
471 if packet.payload.len() < 4 {
472 return None;
473 }
474 let event = packet.payload[0];
475 let e_r_volume = packet.payload[1];
476 let end = (e_r_volume & 0x80) != 0;
477 let volume = e_r_volume & 0x3F;
478 let duration = u16::from_be_bytes([packet.payload[2], packet.payload[3]]);
479 let digit = dtmf_event_to_digit(event)?;
480 Some(DtmfEvent {
481 digit,
482 end,
483 duration,
484 volume,
485 sequence_number: packet.sequence_number,
486 timestamp: packet.timestamp,
487 })
488}
489
490#[derive(Debug, Clone)]
491pub struct SessionStats {
492 pub local_addr: SocketAddr,
493 pub remote_addr: SocketAddr,
494 pub codec: CodecType,
495 pub ssrc: u32,
496 pub packets_sent: u64,
497}
498
499#[cfg(test)]
500mod tests {
501 use super::*;
502 use std::net::{IpAddr, Ipv4Addr};
503
504 #[tokio::test]
505 async fn test_session_creation() {
506 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9999);
507 let config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
508 let session = RtpSession::new(config).await.unwrap();
509
510 let addr = session.local_addr();
511 assert_eq!(addr.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
512 assert!(addr.port() > 0);
513 }
514
515 #[tokio::test]
516 async fn test_send_and_receive_audio() {
517 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
518
519 let recv_config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
521 let mut recv_session = RtpSession::new(recv_config).await.unwrap();
522 let recv_addr = recv_session.local_addr();
523
524 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
526 let mut send_session = RtpSession::new(send_config).await.unwrap();
527
528 let samples: Vec<i16> = (0..160)
530 .map(|i| ((i as f64 / 160.0 * std::f64::consts::TAU).sin() * 8000.0) as i16)
531 .collect();
532
533 let sent = send_session.send_audio(&samples).await.unwrap();
534 assert!(sent > 0);
535
536 let (packet, _source) = recv_session.recv_packet().await.unwrap();
538 assert_eq!(packet.payload_type, 0); assert_eq!(packet.sequence_number, 0);
540 assert_eq!(packet.payload.len(), 160);
541
542 let decoded = recv_session.decode_packet(&packet).unwrap();
544 assert_eq!(decoded.len(), 160);
545 }
546
547 #[tokio::test]
548 async fn test_send_multiple_packets() {
549 let recv_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
550 let recv_addr = recv_socket.local_addr().unwrap();
551
552 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
553 let mut send_session = RtpSession::new(send_config).await.unwrap();
554
555 for _ in 0..3 {
557 let samples = vec![0i16; 160];
558 send_session.send_audio(&samples).await.unwrap();
559 }
560
561 let stats = send_session.stats();
562 assert_eq!(stats.packets_sent, 3);
563 }
564
565 #[tokio::test]
566 async fn test_session_stats() {
567 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9999);
568 let config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
569 let session = RtpSession::new(config).await.unwrap();
570
571 let stats = session.stats();
572 assert_eq!(stats.codec, CodecType::Pcmu);
573 assert_eq!(stats.remote_addr, remote_addr);
574 assert_eq!(stats.packets_sent, 0);
575 }
576
577 #[tokio::test]
578 async fn test_silence_frame() {
579 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9999);
580 let config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
581 let session = RtpSession::new(config).await.unwrap();
582
583 let silence = session.silence_frame();
584 assert_eq!(silence.len(), 160);
585 }
586
587 #[tokio::test]
588 async fn test_receive_loop() {
589 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
590
591 let recv_config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
592 let recv_session = RtpSession::new(recv_config).await.unwrap();
593 let recv_addr = recv_session.local_addr();
594
595 let (mut audio_rx, stop_tx) = recv_session.start_receiving(16);
596
597 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
599 let mut send_session = RtpSession::new(send_config).await.unwrap();
600
601 for _ in 0..3 {
603 let samples = vec![1000i16; 160];
604 send_session.send_audio(&samples).await.unwrap();
605 }
606
607 let audio = tokio::time::timeout(
609 std::time::Duration::from_secs(2),
610 audio_rx.recv(),
611 )
612 .await
613 .unwrap()
614 .unwrap();
615
616 assert_eq!(audio.len(), 160);
617
618 let _ = stop_tx.send(()).await;
620 }
621
622 #[tokio::test]
623 async fn test_pcma_session() {
624 let recv_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
625 let recv_addr = recv_socket.local_addr().unwrap();
626
627 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcma);
628 let mut send_session = RtpSession::new(send_config).await.unwrap();
629
630 let samples = vec![5000i16; 160];
631 let sent = send_session.send_audio(&samples).await.unwrap();
632 assert!(sent > 0);
633
634 let mut buf = vec![0u8; 65535];
636 let (len, _) = recv_socket.recv_from(&mut buf).await.unwrap();
637 let packet = RtpPacket::parse(&buf[..len]).unwrap();
638 assert_eq!(packet.payload_type, 8); }
640
641 #[tokio::test]
642 async fn test_send_raw_packet() {
643 let recv_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
644 let recv_addr = recv_socket.local_addr().unwrap();
645
646 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
647 let send_session = RtpSession::new(send_config).await.unwrap();
648
649 let packet = RtpPacket::new(0, 42, 6720, 0xBEEF)
650 .with_payload(vec![0x7F; 160]);
651
652 let sent = send_session.send_packet(&packet).await.unwrap();
653 assert!(sent > 0);
654
655 let mut buf = vec![0u8; 65535];
656 let (len, _) = recv_socket.recv_from(&mut buf).await.unwrap();
657 let received = RtpPacket::parse(&buf[..len]).unwrap();
658 assert_eq!(received.sequence_number, 42);
659 assert_eq!(received.ssrc, 0xBEEF);
660 }
661
662 #[tokio::test]
663 async fn test_send_rfc2833_digit_packet_shape() {
664 let recv_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
665 let recv_addr = recv_socket.local_addr().unwrap();
666 let config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
667 let mut sender = RtpSession::new(config).await.unwrap();
668
669 sender.send_rfc2833_digit('5', 101).await.unwrap();
670
671 let mut buf = vec![0u8; 65535];
672 let (len, _) = recv_socket.recv_from(&mut buf).await.unwrap();
673 let pkt = RtpPacket::parse(&buf[..len]).unwrap();
674 assert_eq!(pkt.payload_type, 101);
675 assert!(pkt.marker);
676 assert_eq!(pkt.payload[0], 5);
677 }
678
679 #[tokio::test]
680 async fn test_queue_and_flush_rfc2833_digits() {
681 let recv_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
682 let recv_addr = recv_socket.local_addr().unwrap();
683 let config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
684 let mut sender = RtpSession::new(config).await.unwrap();
685
686 let queued = sender.queue_rfc2833_digits("12#").unwrap();
687 assert_eq!(queued, 3);
688 assert_eq!(sender.queued_rfc2833_digits(), 3);
689
690 let sent = sender.flush_queued_rfc2833(101, 0).await.unwrap();
691 assert_eq!(sent, 3);
692 assert_eq!(sender.queued_rfc2833_digits(), 0);
693 }
694
695 #[tokio::test]
696 async fn test_receive_event_reports_dtmf() {
697 let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);
698 let recv_config = SessionConfig::new("127.0.0.1:0", remote_addr, CodecType::Pcmu);
699 let recv_session = RtpSession::new(recv_config).await.unwrap();
700 let recv_addr = recv_session.local_addr();
701 let (mut events, stop_tx) = recv_session.start_receiving_events(16, Some(101));
702
703 let send_config = SessionConfig::new("127.0.0.1:0", recv_addr, CodecType::Pcmu);
704 let mut send_session = RtpSession::new(send_config).await.unwrap();
705 send_session.send_rfc2833_digit('9', 101).await.unwrap();
706
707 let evt = tokio::time::timeout(std::time::Duration::from_secs(2), events.recv())
708 .await
709 .unwrap()
710 .unwrap();
711 match evt {
712 ReceiveEvent::Dtmf(dtmf) => assert_eq!(dtmf.digit, '9'),
713 ReceiveEvent::Audio(_) => panic!("expected DTMF event"),
714 }
715
716 let _ = stop_tx.send(()).await;
717 }
718}