rama_proxy/proxydb/
csv.rs

1use super::{Proxy, StringFilter};
2use base64::Engine;
3use base64::engine::general_purpose::STANDARD as ENGINE;
4use rama_net::{
5    address::ProxyAddress,
6    asn::{Asn, InvalidAsn},
7    user::ProxyCredential,
8};
9use std::path::Path;
10use tokio::{
11    fs::File,
12    io::{AsyncBufReadExt, BufReader, Lines},
13};
14
15#[derive(Debug)]
16/// A CSV Reader that can be used to create a [`Proxy`] database from a CSV file or raw data.
17pub struct ProxyCsvRowReader {
18    data: ProxyCsvRowReaderData,
19}
20
21impl ProxyCsvRowReader {
22    /// Create a new [`ProxyCsvRowReader`] from the given CSV file.
23    pub async fn open(path: impl AsRef<Path>) -> Result<Self, ProxyCsvRowReaderError> {
24        let file = tokio::fs::File::open(path).await?;
25        let reader = BufReader::new(file);
26        let lines = reader.lines();
27        Ok(Self {
28            data: ProxyCsvRowReaderData::File(lines),
29        })
30    }
31
32    /// Create a new [`ProxyCsvRowReader`] from the given CSV data.
33    pub fn raw(data: impl AsRef<str>) -> Self {
34        let lines: Vec<_> = data.as_ref().lines().rev().map(str::to_owned).collect();
35        Self {
36            data: ProxyCsvRowReaderData::Raw(lines),
37        }
38    }
39
40    /// Read the next row from the CSV file.
41    pub async fn next(&mut self) -> Result<Option<Proxy>, ProxyCsvRowReaderError> {
42        match &mut self.data {
43            ProxyCsvRowReaderData::File(lines) => {
44                let line = lines.next_line().await?;
45                match line {
46                    Some(line) => Ok(Some(match parse_csv_row(&line) {
47                        Some(proxy) => proxy,
48                        None => {
49                            return Err(ProxyCsvRowReaderError {
50                                kind: ProxyCsvRowReaderErrorKind::InvalidRow(line),
51                            });
52                        }
53                    })),
54                    None => Ok(None),
55                }
56            }
57            ProxyCsvRowReaderData::Raw(lines) => match lines.pop() {
58                Some(line) => Ok(Some(match parse_csv_row(&line) {
59                    Some(proxy) => proxy,
60                    None => {
61                        return Err(ProxyCsvRowReaderError {
62                            kind: ProxyCsvRowReaderErrorKind::InvalidRow(line),
63                        });
64                    }
65                })),
66                None => Ok(None),
67            },
68        }
69    }
70}
71
/// Removes one pair of surrounding double quotes from a CSV field, if present.
///
/// A field is only unwrapped when it both starts and ends with `"`; anything
/// else (including a lone `"`) is returned unchanged. Inner escaped quotes
/// (`""`) are not unescaped.
fn strip_csv_quotes(p: &str) -> &str {
    let wrapped = p.len() >= 2 && p.starts_with('"') && p.ends_with('"');
    if wrapped {
        // Both ends are the ASCII byte `"`, so these are valid char boundaries.
        &p[1..p.len() - 1]
    } else {
        p
    }
}
77
78pub(crate) fn parse_csv_row(row: &str) -> Option<Proxy> {
79    let mut iter = row.split(',').map(strip_csv_quotes);
80
81    let id = iter.next().and_then(|s| s.try_into().ok())?;
82
83    let tcp = iter.next().and_then(parse_csv_bool)?;
84    let udp = iter.next().and_then(parse_csv_bool)?;
85    let http = iter.next().and_then(parse_csv_bool)?;
86    let https = iter.next().and_then(parse_csv_bool)?;
87    let socks5 = iter.next().and_then(parse_csv_bool)?;
88    let socks5h = iter.next().and_then(parse_csv_bool)?;
89    let datacenter = iter.next().and_then(parse_csv_bool)?;
90    let residential = iter.next().and_then(parse_csv_bool)?;
91    let mobile = iter.next().and_then(parse_csv_bool)?;
92    let mut address = iter.next().and_then(|s| {
93        if s.is_empty() {
94            None
95        } else {
96            ProxyAddress::try_from(s).ok()
97        }
98    })?;
99    let pool_id = parse_csv_opt_string_filter(iter.next()?);
100    let continent = parse_csv_opt_string_filter(iter.next()?);
101    let country = parse_csv_opt_string_filter(iter.next()?);
102    let state = parse_csv_opt_string_filter(iter.next()?);
103    let city = parse_csv_opt_string_filter(iter.next()?);
104    let carrier = parse_csv_opt_string_filter(iter.next()?);
105    let asn = parse_csv_opt_asn(iter.next()?).ok()?;
106
107    // support header format or cleartext format
108    if let Some(value) = iter.next()
109        && !value.is_empty()
110    {
111        address.credential = Some(match value.split_once(' ') {
112            Some((t, v)) => {
113                if t.eq_ignore_ascii_case("basic") {
114                    let bytes = ENGINE.decode(v).ok()?;
115                    let decoded = String::from_utf8(bytes).ok()?;
116                    ProxyCredential::Basic(decoded.parse().ok()?)
117                } else if t.eq_ignore_ascii_case("bearer") {
118                    ProxyCredential::Bearer(v.parse().ok()?)
119                } else {
120                    ProxyCredential::Basic(value.parse().ok()?)
121                }
122            }
123            None => ProxyCredential::Basic(value.parse().ok()?),
124        });
125    }
126
127    // Ensure there are no more values in the row
128    if iter.next().is_some() {
129        return None;
130    }
131
132    Some(Proxy {
133        id,
134        address,
135        tcp,
136        udp,
137        http,
138        https,
139        socks5,
140        socks5h,
141        datacenter,
142        residential,
143        mobile,
144        pool_id,
145        continent,
146        country,
147        state,
148        city,
149        carrier,
150        asn,
151    })
152}
153
/// Parses a CSV boolean flag, comparing ASCII case-insensitively.
///
/// `true` and `1` map to `Some(true)`. An empty field, `false`, `0`, `null` and
/// `nil` map to `Some(false)`. Any other value yields `None`.
fn parse_csv_bool(value: &str) -> Option<bool> {
    const TRUTHY: &[&str] = &["true", "1"];
    const FALSY: &[&str] = &["", "false", "0", "null", "nil"];

    let matches_any = |options: &[&str]| options.iter().any(|o| o.eq_ignore_ascii_case(value));

    if matches_any(TRUTHY) {
        Some(true)
    } else if matches_any(FALSY) {
        Some(false)
    } else {
        None
    }
}
163
164fn parse_csv_opt_string_filter(value: &str) -> Option<StringFilter> {
165    if value.is_empty() {
166        None
167    } else {
168        Some(StringFilter::from(value))
169    }
170}
171
172fn parse_csv_opt_asn(value: &str) -> Result<Option<Asn>, InvalidAsn> {
173    if value.is_empty() {
174        Ok(None)
175    } else {
176        Asn::try_from(value).map(Some)
177    }
178}
179
180#[derive(Debug)]
181enum ProxyCsvRowReaderData {
182    File(Lines<BufReader<File>>),
183    Raw(Vec<String>),
184}
185
/// An error that can occur when reading a Proxy CSV row.
///
/// Use [`ProxyCsvRowReaderError::kind`] (or [`ProxyCsvRowReaderError::into_kind`])
/// to find out what went wrong.
#[derive(Debug)]
pub struct ProxyCsvRowReaderError {
    kind: ProxyCsvRowReaderErrorKind,
}

impl ProxyCsvRowReaderError {
    /// Returns the kind of error that occurred.
    #[must_use]
    pub fn kind(&self) -> &ProxyCsvRowReaderErrorKind {
        &self.kind
    }

    /// Consumes the error and returns its kind, including the owned payload
    /// (the I/O error or the raw invalid row).
    #[must_use]
    pub fn into_kind(self) -> ProxyCsvRowReaderErrorKind {
        self.kind
    }
}

/// The kind of error that can occur when reading a Proxy CSV row.
#[derive(Debug)]
pub enum ProxyCsvRowReaderErrorKind {
    /// An I/O error occurred while reading the CSV row.
    IoError(std::io::Error),
    /// The CSV row is invalid, and could not be parsed; holds the raw row text.
    InvalidRow(String),
}

impl std::fmt::Display for ProxyCsvRowReaderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.kind {
            ProxyCsvRowReaderErrorKind::IoError(err) => write!(f, "I/O error: {err}"),
            ProxyCsvRowReaderErrorKind::InvalidRow(row) => write!(f, "Invalid row: {row}"),
        }
    }
}

impl std::error::Error for ProxyCsvRowReaderError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match &self.kind {
            ProxyCsvRowReaderErrorKind::IoError(err) => Some(err),
            ProxyCsvRowReaderErrorKind::InvalidRow(_) => None,
        }
    }
}

impl From<std::io::Error> for ProxyCsvRowReaderError {
    fn from(err: std::io::Error) -> Self {
        Self {
            kind: ProxyCsvRowReaderErrorKind::IoError(err),
        }
    }
}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ProxyFilter, proxydb::ProxyContext};
    use rama_net::transport::TransportProtocol;
    use rama_utils::str::NonEmptyString;
    use std::str::FromStr;

    /// Fields of a minimal valid row (18 columns, no credential column):
    /// only the id (column 0) and the address (column 10) are set.
    fn minimal_row_fields() -> Vec<&'static str> {
        let mut fields = vec![""; 18];
        fields[0] = "id";
        fields[10] = "authority";
        fields
    }

    /// Joins `minimal_row_fields` into a CSV row after applying `edit` to the fields,
    /// so each invalid-row case differs from a known-valid row in exactly one way.
    fn mutated_row(edit: impl FnOnce(&mut Vec<&'static str>)) -> String {
        let mut fields = minimal_row_fields();
        edit(&mut fields);
        fields.join(",")
    }

    #[test]
    fn test_parse_csv_bool() {
        for (input, output) in &[
            ("1", Some(true)),
            ("true", Some(true)),
            ("True", Some(true)),
            ("TRUE", Some(true)),
            ("0", Some(false)),
            ("false", Some(false)),
            ("False", Some(false)),
            ("FALSE", Some(false)),
            ("null", Some(false)),
            ("nil", Some(false)),
            ("NULL", Some(false)),
            ("NIL", Some(false)),
            ("", Some(false)),
            ("invalid", None),
        ] {
            assert_eq!(parse_csv_bool(input), *output);
        }
    }

    #[test]
    fn test_parse_csv_opt_string_filter() {
        for (input, output) in [
            ("", None),
            ("value", Some("value")),
            ("*", Some("*")),
            ("Foo", Some("foo")),
            ("  ok ", Some("ok")),
            (" NO  ", Some("no")),
        ] {
            assert_eq!(
                parse_csv_opt_string_filter(input)
                    .as_ref()
                    .map(|f| f.as_ref()),
                output,
            );
        }
    }

    #[test]
    fn test_parse_csv_opt_string_filter_is_any() {
        let filter = parse_csv_opt_string_filter("*").unwrap();
        assert!(venndb::Any::is_any(&filter));
    }

    #[test]
    fn test_parse_csv_row_happy_path() {
        for (input, output) in [
            // most minimal
            (
                "id,,,,,,,,,,authority,,,,,,,,",
                Proxy {
                    id: NonEmptyString::from_static("id"),
                    address: ProxyAddress::from_str("authority").unwrap(),
                    tcp: false,
                    udp: false,
                    http: false,
                    https: false,
                    socks5: false,
                    socks5h: false,
                    datacenter: false,
                    residential: false,
                    mobile: false,
                    pool_id: None,
                    continent: None,
                    country: None,
                    state: None,
                    city: None,
                    carrier: None,
                    asn: None,
                },
            ),
            // more happy row tests
            (
                "id,true,false,true,,false,,true,false,true,authority,pool_id,,country,,city,carrier,,Basic dXNlcm5hbWU6cGFzc3dvcmQ=",
                Proxy {
                    id: NonEmptyString::from_static("id"),
                    address: ProxyAddress::from_str("username:password@authority").unwrap(),
                    tcp: true,
                    udp: false,
                    http: true,
                    https: false,
                    socks5: false,
                    socks5h: false,
                    datacenter: true,
                    residential: false,
                    mobile: true,
                    pool_id: Some("pool_id".into()),
                    continent: None,
                    country: Some("country".into()),
                    state: None,
                    city: Some("city".into()),
                    carrier: Some("carrier".into()),
                    asn: None,
                },
            ),
            (
                "123,1,0,False,,True,,null,false,true,host:1234,,americas,*,*,*,carrier,13335,",
                Proxy {
                    id: NonEmptyString::from_static("123"),
                    address: ProxyAddress::from_str("host:1234").unwrap(),
                    tcp: true,
                    udp: false,
                    http: false,
                    https: false,
                    socks5: true,
                    socks5h: false,
                    datacenter: false,
                    residential: false,
                    mobile: true,
                    pool_id: None,
                    continent: Some("americas".into()),
                    country: Some("*".into()),
                    state: Some("*".into()),
                    city: Some("*".into()),
                    carrier: Some("carrier".into()),
                    asn: Some(Asn::from_static(13335)),
                },
            ),
            (
                "123,1,0,False,,True,,null,false,true,host:1234,,europe,*,,*,carrier,0",
                Proxy {
                    id: NonEmptyString::from_static("123"),
                    address: ProxyAddress::from_str("host:1234").unwrap(),
                    tcp: true,
                    udp: false,
                    http: false,
                    https: false,
                    socks5: true,
                    socks5h: false,
                    datacenter: false,
                    residential: false,
                    mobile: true,
                    pool_id: None,
                    continent: Some("europe".into()),
                    country: Some("*".into()),
                    state: None,
                    city: Some("*".into()),
                    carrier: Some("carrier".into()),
                    asn: Some(Asn::unspecified()),
                },
            ),
            (
                "foo,1,0,1,,0,,1,0,0,bar,baz,,US,,,,",
                Proxy {
                    id: NonEmptyString::from_static("foo"),
                    address: ProxyAddress::from_str("bar").unwrap(),
                    tcp: true,
                    udp: false,
                    http: true,
                    https: false,
                    socks5: false,
                    socks5h: false,
                    datacenter: true,
                    residential: false,
                    mobile: false,
                    pool_id: Some("baz".into()),
                    continent: None,
                    country: Some("us".into()),
                    state: None,
                    city: None,
                    carrier: None,
                    asn: None,
                },
            ),
        ] {
            let proxy = parse_csv_row(input).unwrap();
            assert_eq!(proxy.id, output.id);
            assert_eq!(proxy.address, output.address);
            assert_eq!(proxy.tcp, output.tcp);
            assert_eq!(proxy.udp, output.udp);
            assert_eq!(proxy.http, output.http);
            assert_eq!(proxy.https, output.https);
            assert_eq!(proxy.socks5, output.socks5);
            assert_eq!(proxy.socks5h, output.socks5h);
            assert_eq!(proxy.datacenter, output.datacenter);
            assert_eq!(proxy.residential, output.residential);
            assert_eq!(proxy.mobile, output.mobile);
            assert_eq!(proxy.pool_id, output.pool_id);
            assert_eq!(proxy.continent, output.continent);
            assert_eq!(proxy.country, output.country);
            assert_eq!(proxy.state, output.state);
            assert_eq!(proxy.city, output.city);
            assert_eq!(proxy.carrier, output.carrier);
            assert_eq!(proxy.asn, output.asn);
        }
    }

    #[test]
    fn test_parse_csv_row_mistakes() {
        // garbage rows
        for input in [
            "",
            ",",
            ",,,,,,",
            ",,,,,,,,,,,,,,,,,,,,",
            ",,,,,,,,,,,,,,,,,,,,,,",
            ",,,,,,,,,,,,,,,,,,,,,,,",
        ] {
            assert!(parse_csv_row(input).is_none(), "input: {input}");
        }

        // Sanity check: the base row is valid, so every row below fails
        // because of its single mutation and nothing else.
        assert!(parse_csv_row(&minimal_row_fields().join(",")).is_some());

        let mut invalid_rows = vec![
            // too many columns: a complete row followed by one extra (empty) column
            "id,true,false,true,false,true,false,true,false,true,authority,pool_id,continent,country,state,city,carrier,15169,Basic dXNlcm5hbWU6cGFzc3dvcmQ=,".to_owned(),
            mutated_row(|f| f.extend(["", ""])),
            // too few columns: asn column missing
            mutated_row(|f| {
                f.pop();
            }),
            // missing proxy id
            mutated_row(|f| f[0] = ""),
            // missing authority
            mutated_row(|f| f[10] = ""),
            // invalid asn
            mutated_row(|f| f[17] = "foo"),
            // invalid credentials: payload is not base64
            mutated_row(|f| f.push("Basic !!!")),
            // invalid credentials: base64 of bytes that are not valid UTF-8
            mutated_row(|f| f.push("Basic /w==")),
        ];
        // invalid bool values, one per flag column (tcp..=mobile)
        invalid_rows.extend((1..=9).map(|column| mutated_row(|f| f[column] = "foo")));

        for input in &invalid_rows {
            assert!(parse_csv_row(input).is_none(), "input: {input}");
        }
    }

    #[tokio::test]
    async fn test_proxy_csv_row_reader_happy_one_row() {
        let mut reader = ProxyCsvRowReader::raw(
            "id,true,false,true,,false,,true,false,true,authority,pool_id,continent,country,state,city,carrier,13335,Basic dXNlcm5hbWU6cGFzc3dvcmQ=",
        );
        let proxy = reader.next().await.unwrap().unwrap();

        assert_eq!(proxy.id, "id");
        assert_eq!(
            proxy.address,
            ProxyAddress::from_str("username:password@authority").unwrap()
        );
        assert!(proxy.tcp);
        assert!(!proxy.udp);
        assert!(proxy.http);
        assert!(!proxy.https);
        assert!(!proxy.socks5);
        assert!(!proxy.socks5h);
        assert!(proxy.datacenter);
        assert!(!proxy.residential);
        assert!(proxy.mobile);
        assert_eq!(proxy.pool_id, Some("pool_id".into()));
        assert_eq!(proxy.continent, Some("continent".into()));
        assert_eq!(proxy.country, Some("country".into()));
        assert_eq!(proxy.state, Some("state".into()));
        assert_eq!(proxy.city, Some("city".into()));
        assert_eq!(proxy.carrier, Some("carrier".into()));
        assert_eq!(proxy.asn, Some(Asn::from_static(13335)));

        // no more rows to read
        assert!(reader.next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_proxy_csv_row_reader_happy_multi_row() {
        let mut reader = ProxyCsvRowReader::raw(
            "id,true,false,false,true,true,false,true,false,true,authority,pool_id,continent,country,state,city,carrier,42,Basic dXNlcm5hbWU6cGFzc3dvcmQ=\nid2,1,0,0,0,0,0,1,0,0,authority2,pool_id2,continent2,country2,state2,city2,carrier2,1",
        );

        let proxy = reader.next().await.unwrap().unwrap();
        assert_eq!(proxy.id, "id");
        assert_eq!(
            proxy.address,
            ProxyAddress::from_str("username:password@authority").unwrap()
        );
        assert!(proxy.tcp);
        assert!(!proxy.udp);
        assert!(!proxy.http);
        assert!(proxy.https);
        assert!(proxy.socks5);
        assert!(!proxy.socks5h);
        assert!(proxy.datacenter);
        assert!(!proxy.residential);
        assert!(proxy.mobile);
        assert_eq!(proxy.pool_id, Some("pool_id".into()));
        assert_eq!(proxy.continent, Some("continent".into()));
        assert_eq!(proxy.country, Some("country".into()));
        assert_eq!(proxy.state, Some("state".into()));
        assert_eq!(proxy.city, Some("city".into()));
        assert_eq!(proxy.carrier, Some("carrier".into()));
        assert_eq!(proxy.asn, Some(Asn::from_static(42)));

        let proxy = reader.next().await.unwrap().unwrap();

        assert_eq!(proxy.id, "id2");
        assert_eq!(proxy.address, ProxyAddress::from_str("authority2").unwrap());
        assert!(proxy.tcp);
        assert!(!proxy.udp);
        assert!(!proxy.http);
        assert!(!proxy.https);
        assert!(!proxy.socks5);
        assert!(!proxy.socks5h);
        assert!(proxy.datacenter);
        assert!(!proxy.residential);
        assert!(!proxy.mobile);
        assert_eq!(proxy.pool_id, Some("pool_id2".into()));
        assert_eq!(proxy.continent, Some("continent2".into()));
        assert_eq!(proxy.country, Some("country2".into()));
        assert_eq!(proxy.city, Some("city2".into()));
        assert_eq!(proxy.state, Some("state2".into()));
        assert_eq!(proxy.carrier, Some("carrier2".into()));
        assert_eq!(proxy.asn, Some(Asn::from_static(1)));

        // no more rows to read
        assert!(reader.next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_proxy_csv_row_reader_failure_empty_data() {
        let mut reader = ProxyCsvRowReader::raw("");
        assert!(reader.next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_proxy_csv_row_reader_failure_invalid_row() {
        let mut reader = ProxyCsvRowReader::raw(",,,,,,,,,,,");
        assert!(reader.next().await.is_err());
    }

    #[test]
    fn test_proxy_is_match_happy_path_proxy_with_any_filter_string_cases() {
        let proxy = parse_csv_row("id,1,,1,,,,,,,authority,*,*,*,*,*,*,0").unwrap();
        let ctx = ProxyContext {
            protocol: TransportProtocol::Tcp,
        };

        for filter in [
            ProxyFilter::default(),
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool_a")]),
                country: Some(vec![StringFilter::new("country_a")]),
                city: Some(vec![StringFilter::new("city_a")]),
                carrier: Some(vec![StringFilter::new("carrier_a")]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool_a")]),
                ..Default::default()
            },
            ProxyFilter {
                continent: Some(vec![StringFilter::new("continent_a")]),
                ..Default::default()
            },
            ProxyFilter {
                country: Some(vec![StringFilter::new("country_a")]),
                ..Default::default()
            },
            ProxyFilter {
                state: Some(vec![StringFilter::new("state_a")]),
                ..Default::default()
            },
            ProxyFilter {
                city: Some(vec![StringFilter::new("city_a")]),
                carrier: Some(vec![StringFilter::new("carrier_a")]),
                ..Default::default()
            },
            ProxyFilter {
                carrier: Some(vec![StringFilter::new("carrier_a")]),
                ..Default::default()
            },
        ] {
            assert!(proxy.is_match(&ctx, &filter), "filter: {filter:?}");
        }
    }

    #[test]
    fn test_proxy_is_match_happy_path_proxy_with_any_filters_cases() {
        let proxy =
            parse_csv_row("id,1,,1,,,,,,,authority,pool,continent,country,state,city,carrier,42")
                .unwrap();
        let ctx = ProxyContext {
            protocol: TransportProtocol::Tcp,
        };

        for filter in [
            ProxyFilter::default(),
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                continent: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                country: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                state: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                city: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                carrier: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool")]),
                continent: Some(vec![StringFilter::new("continent")]),
                country: Some(vec![StringFilter::new("country")]),
                state: Some(vec![StringFilter::new("state")]),
                city: Some(vec![StringFilter::new("city")]),
                carrier: Some(vec![StringFilter::new("carrier")]),
                asn: Some(vec![Asn::from_static(42)]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("*")]),
                country: Some(vec![StringFilter::new("country")]),
                city: Some(vec![StringFilter::new("city")]),
                carrier: Some(vec![StringFilter::new("carrier")]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool")]),
                country: Some(vec![StringFilter::new("*")]),
                city: Some(vec![StringFilter::new("city")]),
                carrier: Some(vec![StringFilter::new("carrier")]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool")]),
                country: Some(vec![StringFilter::new("country")]),
                city: Some(vec![StringFilter::new("*")]),
                carrier: Some(vec![StringFilter::new("carrier")]),
                ..Default::default()
            },
            ProxyFilter {
                pool_id: Some(vec![StringFilter::new("pool")]),
                country: Some(vec![StringFilter::new("country")]),
                city: Some(vec![StringFilter::new("city")]),
                carrier: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
            ProxyFilter {
                continent: Some(vec![StringFilter::new("*")]),
                ..Default::default()
            },
        ] {
            assert!(proxy.is_match(&ctx, &filter), "filter: {filter:?}");
        }
    }
}