ip2location_bin_format/
querier.rs

1use core::{cmp::max, future::Future, ops::ControlFlow, pin::Pin};
2use std::{
3    io::{Cursor, Error as IoError, SeekFrom},
4    net::{IpAddr, Ipv4Addr, Ipv6Addr},
5};
6
7use deadpool::unmanaged::{Pool, PoolError};
8use futures_util::{AsyncRead, AsyncReadExt as _, AsyncSeek, AsyncSeekExt as _};
9
10use crate::{
11    content::{querier::FillError as ContentFillError, Querier as ContentQuerier},
12    header::{
13        parser::ParseError as HeaderParseError, Parser as HeaderParser, Schema as HeaderSchema,
14        HEADER_LEN,
15    },
16    index::{
17        querier::BuildError as IndexBuildError, V4Querier as IndexV4Querier,
18        V6Querier as IndexV6Querier, INDEX_LEN,
19    },
20    record_field::{RecordField, RecordFieldContents},
21    records::{
22        querier::v4_querier::NewError as RecordsV4QuerierNewError,
23        querier::v6_querier::NewError as RecordsV6QuerierNewError,
24        querier::Error as RecordsQueryError, V4Querier as RecordsV4Querier,
25        V6Querier as RecordsV6Querier,
26    },
27};
28
29//
30pub struct Querier<S> {
31    pub header: HeaderSchema,
32    pub index_v4: IndexV4Querier,
33    pub index_v6: Option<IndexV6Querier>,
34    pub records_v4_pool: Pool<RecordsV4Querier<S>>,
35    pub records_v6_pool: Option<Pool<RecordsV6Querier<S>>>,
36    pub content_pool: Pool<ContentQuerier<S>>,
37}
38
39impl<S> core::fmt::Debug for Querier<S>
40where
41    Pool<RecordsV4Querier<S>>: core::fmt::Debug,
42    Option<Pool<RecordsV6Querier<S>>>: core::fmt::Debug,
43    Pool<ContentQuerier<S>>: core::fmt::Debug,
44{
45    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
46        f.debug_struct("Querier")
47            .field("header", &self.header)
48            .field("index_v4", &self.index_v4)
49            .field("index_v6", &self.index_v6)
50            .field("records_v4_pool", &self.records_v4_pool)
51            .field("records_v6_pool", &self.records_v6_pool)
52            .field("content_pool", &self.content_pool)
53            .finish()
54    }
55}
56
57//
58//
59//
60impl<S> Querier<S>
61where
62    S: AsyncSeek + AsyncRead + Unpin,
63{
64    pub async fn new<F>(mut stream_repeater: F, pool_max_size: usize) -> Result<Self, NewError>
65    where
66        F: FnMut() -> Pin<Box<dyn Future<Output = Result<S, IoError>> + Send + 'static>>,
67    {
68        let pool_max_size = max(1, pool_max_size);
69
70        let mut buf = vec![0; 1024 * 8];
71
72        //
73        let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
74
75        //
76        let header = {
77            let mut parser = HeaderParser::new();
78            let mut n_read = 0;
79            let mut n_parsed = 0;
80            loop {
81                let n = stream
82                    .read(&mut buf[n_read..n_read + HEADER_LEN as usize])
83                    .await
84                    .map_err(NewError::ReadFailed)?;
85
86                if n == 0 {
87                    return Err(NewError::ReadOtherError("header parsing is not completed"));
88                }
89
90                n_read += n;
91
92                match parser
93                    .parse(&mut Cursor::new(&buf[n_parsed..n_read]))
94                    .map_err(NewError::HeaderParseFailed)?
95                {
96                    ControlFlow::Continue(n) => {
97                        n_parsed += n;
98                        continue;
99                    }
100                    ControlFlow::Break((_n, header)) => {
101                        break header;
102                    }
103                }
104            }
105        };
106
107        //
108        stream
109            .seek(SeekFrom::Start(header.total_size as u64))
110            .await
111            .map_err(NewError::SeekFailed)?;
112        let n = stream
113            .read(&mut buf[..1])
114            .await
115            .map_err(NewError::ReadFailed)?;
116        if n != 0 {
117            return Err(NewError::TotalSizeMissing);
118        }
119
120        //
121        let index_v4 = {
122            let mut builder = IndexV4Querier::builder();
123            let mut n_max_appended = INDEX_LEN as usize;
124            stream
125                .seek(SeekFrom::Start(header.v4_index_seek_from_start()))
126                .await
127                .map_err(NewError::ReadFailed)?;
128            loop {
129                let n = stream
130                    .read(&mut buf[..])
131                    .await
132                    .map_err(NewError::ReadFailed)?;
133
134                if n == 0 {
135                    return Err(NewError::ReadOtherError(
136                        "index_v4 building is not completed",
137                    ));
138                }
139
140                if n < n_max_appended {
141                    builder.append(&buf[..n]);
142
143                    n_max_appended -= n;
144                    continue;
145                } else {
146                    builder.append(&buf[..n_max_appended]);
147
148                    break builder
149                        .finish::<IndexV4Querier>()
150                        .map_err(NewError::IndexV4BuildFailed)?;
151                }
152            }
153        };
154
155        //
156        let mut index_v6 = None;
157        #[allow(clippy::unnecessary_operation)]
158        {
159            if let Some(v6_index_seek_from_start) = header.v6_index_seek_from_start() {
160                let mut builder = IndexV6Querier::builder();
161                let mut n_max_appended = INDEX_LEN as usize;
162                stream
163                    .seek(SeekFrom::Start(v6_index_seek_from_start))
164                    .await
165                    .map_err(NewError::ReadFailed)?;
166                let index_v6_tmp = loop {
167                    let n = stream
168                        .read(&mut buf[..])
169                        .await
170                        .map_err(NewError::ReadFailed)?;
171
172                    if n == 0 {
173                        return Err(NewError::ReadOtherError(
174                            "index_v6 building is not completed",
175                        ));
176                    }
177
178                    if n < n_max_appended {
179                        builder.append(&buf[..n]);
180
181                        n_max_appended -= n;
182                        continue;
183                    } else {
184                        builder.append(&buf[..n_max_appended]);
185
186                        break builder
187                            .finish::<IndexV6Querier>()
188                            .map_err(NewError::IndexV6BuildFailed)?;
189                    }
190                };
191                index_v6 = Some(index_v6_tmp);
192            }
193        };
194
195        let records_v4_pool = {
196            let mut pool_objs = vec![];
197
198            for _ in 0..pool_max_size {
199                let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
200
201                stream
202                    .seek(SeekFrom::Start(header.total_size as u64))
203                    .await
204                    .map_err(NewError::SeekFailed)?;
205                let n = stream
206                    .read(&mut buf[..1])
207                    .await
208                    .map_err(NewError::ReadFailed)?;
209                if n != 0 {
210                    return Err(NewError::TotalSizeMissing);
211                }
212
213                //
214                let pool_obj = RecordsV4Querier::new(stream, header)
215                    .map_err(NewError::RecordsV4QuerierNewFailed)?;
216
217                pool_objs.push(pool_obj);
218            }
219
220            Pool::from(pool_objs)
221        };
222
223        let mut records_v6_pool = None;
224        #[allow(clippy::unnecessary_operation)]
225        {
226            if header.has_v6() {
227                let mut pool_objs = vec![];
228
229                for _ in 0..pool_max_size {
230                    let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
231
232                    stream
233                        .seek(SeekFrom::Start(header.total_size as u64))
234                        .await
235                        .map_err(NewError::SeekFailed)?;
236                    let n = stream
237                        .read(&mut buf[..1])
238                        .await
239                        .map_err(NewError::ReadFailed)?;
240                    if n != 0 {
241                        return Err(NewError::TotalSizeMissing);
242                    }
243
244                    //
245                    let pool_obj = RecordsV6Querier::new(stream, header)
246                        .map_err(NewError::RecordsV6QuerierNewFailed)?;
247
248                    pool_objs.push(pool_obj);
249                }
250
251                records_v6_pool = Some(Pool::from(pool_objs))
252            }
253        };
254
255        let content_pool = {
256            let mut pool_objs = vec![];
257
258            for _ in 0..pool_max_size {
259                let mut stream = stream_repeater().await.map_err(NewError::OpenFailed)?;
260
261                stream
262                    .seek(SeekFrom::Start(header.total_size as u64))
263                    .await
264                    .map_err(NewError::SeekFailed)?;
265                let n = stream
266                    .read(&mut buf[..1])
267                    .await
268                    .map_err(NewError::ReadFailed)?;
269                if n != 0 {
270                    return Err(NewError::TotalSizeMissing);
271                }
272
273                //
274                let pool_obj = ContentQuerier::new(stream);
275
276                pool_objs.push(pool_obj);
277            }
278
279            Pool::from(pool_objs)
280        };
281
282        //
283        Ok(Self {
284            header,
285            index_v4,
286            index_v6,
287            records_v4_pool,
288            records_v6_pool,
289            content_pool,
290        })
291    }
292}
293
294//
295#[derive(Debug)]
296pub enum NewError {
297    OpenFailed(IoError),
298    SeekFailed(IoError),
299    ReadFailed(IoError),
300    ReadOtherError(&'static str),
301    HeaderParseFailed(HeaderParseError),
302    TotalSizeMissing,
303    IndexV4BuildFailed(IndexBuildError),
304    IndexV6BuildFailed(IndexBuildError),
305    RecordsV4QuerierNewFailed(RecordsV4QuerierNewError),
306    RecordsV6QuerierNewFailed(RecordsV6QuerierNewError),
307}
308
309impl core::fmt::Display for NewError {
310    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
311        write!(f, "{self:?}")
312    }
313}
314
315impl std::error::Error for NewError {}
316
317//
318//
319//
320impl<S> Querier<S>
321where
322    S: AsyncSeek + AsyncRead + Unpin,
323{
324    pub async fn lookup(
325        &self,
326        ip: IpAddr,
327        selected_fields: Option<&[RecordField]>,
328    ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
329        match ip {
330            IpAddr::V4(ip) => self.lookup_ipv4(ip, selected_fields).await,
331            IpAddr::V6(ip) => self.lookup_ipv6(ip, selected_fields).await,
332        }
333    }
334
335    pub async fn lookup_ipv4(
336        &self,
337        ip: Ipv4Addr,
338        selected_fields: Option<&[RecordField]>,
339    ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
340        let position_range = self.index_v4.query(ip);
341
342        if position_range.end == 0 {
343            return Ok(None);
344        }
345
346        //
347        let mut records_v4 = self
348            .records_v4_pool
349            .get()
350            .await
351            .map_err(LookupError::PoolGetFailed)?;
352        let (ip_from, ip_to, mut record_field_contents) = match records_v4
353            .query(ip, position_range)
354            .await
355            .map_err(LookupError::RecordsQueryFailed)?
356        {
357            Some(x) => x,
358            None => return Ok(None),
359        };
360
361        if let Some(selected_fields) = selected_fields {
362            record_field_contents.select(selected_fields);
363        }
364
365        //
366        let mut content = self
367            .content_pool
368            .get()
369            .await
370            .map_err(LookupError::PoolGetFailed)?;
371
372        content
373            .fill(&mut record_field_contents)
374            .await
375            .map_err(LookupError::ContentFillFailed)?;
376
377        Ok(Some((ip_from, ip_to, record_field_contents)))
378    }
379
380    pub async fn lookup_ipv6(
381        &self,
382        ip: Ipv6Addr,
383        selected_fields: Option<&[RecordField]>,
384    ) -> Result<Option<(IpAddr, IpAddr, RecordFieldContents)>, LookupError> {
385        if let Some(ip) = ip.to_ipv4() {
386            return self.lookup_ipv4(ip, selected_fields).await.map(|x| {
387                x.map(|(ip_from, ip_to, record_field_contents)| {
388                    (
389                        match ip_from {
390                            IpAddr::V4(ip) => ip.to_ipv6_mapped().into(),
391                            IpAddr::V6(ip) => {
392                                debug_assert!(false, "unreachable");
393                                ip.into()
394                            }
395                        },
396                        match ip_to {
397                            IpAddr::V4(ip) => ip.to_ipv6_mapped().into(),
398                            IpAddr::V6(ip) => {
399                                debug_assert!(false, "unreachable");
400                                ip.into()
401                            }
402                        },
403                        record_field_contents,
404                    )
405                })
406            });
407        }
408
409        let position_range = self
410            .index_v6
411            .as_ref()
412            .map(|x| x.query(ip))
413            .unwrap_or_default();
414
415        if position_range.end == 0 {
416            return Ok(None);
417        }
418
419        let (ip_from, ip_to, mut record_field_contents) = match self.records_v6_pool.as_ref() {
420            Some(records_v6_pool) => {
421                //
422                let mut records_v6 = records_v6_pool
423                    .get()
424                    .await
425                    .map_err(LookupError::PoolGetFailed)?;
426
427                match records_v6
428                    .query(ip, position_range)
429                    .await
430                    .map_err(LookupError::RecordsQueryFailed)?
431                {
432                    Some(x) => x,
433                    None => return Ok(None),
434                }
435            }
436            None => return Ok(None),
437        };
438
439        if let Some(selected_fields) = selected_fields {
440            record_field_contents.select(selected_fields);
441        }
442
443        //
444        let mut content = self
445            .content_pool
446            .get()
447            .await
448            .map_err(LookupError::PoolGetFailed)?;
449
450        content
451            .fill(&mut record_field_contents)
452            .await
453            .map_err(LookupError::ContentFillFailed)?;
454
455        Ok(Some((ip_from, ip_to, record_field_contents)))
456    }
457}
458
459//
460#[derive(Debug)]
461pub enum LookupError {
462    PoolGetFailed(PoolError),
463    RecordsQueryFailed(RecordsQueryError),
464    ContentFillFailed(ContentFillError),
465}
466
467impl core::fmt::Display for LookupError {
468    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
469        write!(f, "{self:?}")
470    }
471}
472
473impl std::error::Error for LookupError {}
474
#[cfg(test)]
mod tests {
    use super::*;

    use async_compat::Compat;
    use futures_util::TryFutureExt as _;
    use tokio::fs::File as TokioFile;

    use crate::test_helper::{ip2location_bin_files, ip2proxy_bin_files};

    /// The content pool holds the requested number of queriers, with a
    /// request of `0` giving a pool of `1`. Skipped when no BIN file is
    /// available.
    #[tokio::test]
    async fn test_new_pool_max_size() -> Result<(), Box<dyn std::error::Error>> {
        let path = match ip2location_bin_files().first() {
            Some(first) => first.clone(),
            None => return Ok(()),
        };

        for &(requested, expected) in [(0, 1), (1, 1), (2, 2)].iter() {
            let querier = Querier::new(
                || Box::pin(TokioFile::open(path.clone()).map_ok(Compat::new)),
                requested,
            )
            .await?;
            assert_eq!(querier.content_pool.status().size, expected);
        }

        Ok(())
    }

    /// Opens every available BIN file and checks lookups for the extreme
    /// addresses plus a few known addresses in specific files.
    #[tokio::test]
    async fn test_new_and_lookup() -> Result<(), Box<dyn std::error::Error>> {
        // Lowest and highest address of each family; none is expected to
        // resolve to a record.
        let boundary_ips: &[IpAddr] = &[
            Ipv4Addr::new(0, 0, 0, 0).into(),
            Ipv4Addr::new(255, 255, 255, 255).into(),
            Ipv6Addr::new(0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0).into(),
            Ipv6Addr::new(
                0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
            )
            .into(),
        ];

        for path in ip2location_bin_files().iter() {
            let querier = Querier::new(
                || Box::pin(TokioFile::open(path.clone()).map_ok(Compat::new)),
                2,
            )
            .await?;

            for ip in boundary_ips {
                let found = querier.lookup(*ip, None).await?;
                assert!(found.is_none());
            }

            let path_str = path.as_os_str().to_str();

            let is_20220329_ipv6 =
                path_str.map_or(false, |x| x.contains("/20220329") && x.contains("IPV6.BIN"));
            if is_20220329_ipv6 {
                querier
                    .lookup(
                        Ipv6Addr::from(58569107296622255421594597096899477504).into(),
                        None,
                    )
                    .await?
                    .unwrap();
            }

            let is_sample_ip_file = path_str
                .map_or(false, |x| x.contains("/ip2location-sample") && x.contains("/IP-"));
            if is_sample_ip_file {
                let found = querier
                    .lookup(Ipv4Addr::new(8, 8, 8, 8).into(), None)
                    .await?;
                assert!(found.is_some());
            }
        }

        for path in ip2proxy_bin_files().iter() {
            let querier = Querier::new(
                || Box::pin(TokioFile::open(path.clone()).map_ok(Compat::new)),
                2,
            )
            .await?;

            for ip in boundary_ips {
                let found = querier.lookup(*ip, None).await?;
                assert!(found.is_none());
            }

            let found = querier
                .lookup(
                    Ipv6Addr::from(58569071813452613185929873510317667680).into(),
                    None,
                )
                .await?;
            assert!(found.is_some());
        }

        Ok(())
    }
}