tantivy_uffd/
lib.rs

1mod query_len;
2mod uffd;
3mod vec_writer;
4
5use crate::{uffd::round_up_to_page, vec_writer::VecWriter};
6use dashmap::DashMap;
7use log::info;
8use nix::sys::mman::{mmap, MapFlags, ProtFlags};
9use std::{
10    ops::{Deref, Range},
11    path::Path,
12    slice,
13    sync::Arc,
14};
15use tantivy::{
16    directory::{
17        error::{DeleteError, OpenReadError, OpenWriteError},
18        WatchHandle, WritePtr,
19    },
20    Directory,
21};
22use tantivy_common::{file_slice::FileHandle, HasLen, OwnedBytes, StableDeref};
23use tokio::runtime::Runtime;
24use uffd::UffdFile;
25use userfaultfd::UffdBuilder;
26
27thread_local! {
28    pub(crate) static BLOCKING_HTTP_CLIENT: reqwest::blocking::Client = reqwest::blocking::Client::new();
29}
30
31#[derive(Clone)]
32struct MmapArc {
33    slice: &'static [u8],
34}
35
36impl Deref for MmapArc {
37    type Target = [u8];
38
39    #[inline]
40    fn deref(&self) -> &[u8] {
41        self.slice
42    }
43}
44unsafe impl StableDeref for MmapArc {}
45
/// Identifies one chunk of one remote file.
///
/// NOTE(review): not referenced anywhere in this file; presumably used by a child module
/// (e.g. `uffd`) or left over from an earlier design — confirm before removing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Base URL of the directory the file belongs to.
    base_url: String,
    /// Path of the file relative to `base_url`.
    path: String,
    /// Chunk index within the file (presumably in units of `CHUNK_SIZE` — confirm).
    chunk: usize,
}
52
53#[derive(Debug, Clone)]
54struct HttpFileHandle<const CHUNK_SIZE: usize> {
55    owned_bytes: Arc<OwnedBytes>,
56    _uffd_file: Option<Arc<UffdFile<CHUNK_SIZE>>>,
57}
58
59impl<const CHUNK_SIZE: usize> HttpFileHandle<CHUNK_SIZE> {
60    pub(crate) fn new(runtime: Arc<Runtime>, file_size: usize, artifact_url: String) -> Self {
61        let mmap_len = round_up_to_page(file_size, CHUNK_SIZE);
62        let uffd = UffdBuilder::new()
63            .close_on_exec(true)
64            .user_mode_only(true)
65            .create()
66            .unwrap();
67
68        let addr = unsafe {
69            mmap(
70                None,
71                mmap_len.try_into().unwrap(),
72                ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
73                MapFlags::MAP_PRIVATE | MapFlags::MAP_ANONYMOUS | MapFlags::MAP_NORESERVE,
74                None::<std::os::fd::BorrowedFd>,
75                0,
76            )
77            .expect("mmap")
78        };
79
80        let mmap_ptr = addr as usize;
81
82        uffd.register(addr, mmap_len).unwrap();
83
84        let uffd_file = Arc::new(UffdFile::new(
85            Arc::new(uffd),
86            runtime,
87            mmap_ptr,
88            artifact_url.clone(),
89        ));
90        {
91            let uffd_file = uffd_file.clone();
92            std::thread::spawn(move || {
93                uffd_file.handle_faults();
94            });
95        }
96        let owned_bytes = Arc::new(OwnedBytes::new(MmapArc {
97            slice: unsafe { slice::from_raw_parts(mmap_ptr as *const u8, file_size) },
98        }));
99
100        Self {
101            owned_bytes,
102            _uffd_file: Some(uffd_file),
103        }
104    }
105}
106
107impl<const CHUNK_SIZE: usize> FileHandle for HttpFileHandle<CHUNK_SIZE> {
108    fn read_bytes(&self, range: Range<usize>) -> std::io::Result<OwnedBytes> {
109        Ok(self.owned_bytes.slice(range))
110    }
111}
112
113impl<const CHUNK_SIZE: usize> HasLen for HttpFileHandle<CHUNK_SIZE> {
114    fn len(&self) -> usize {
115        self.owned_bytes.len()
116    }
117}
118
119/// HTTP remote directory for tantivy. The directory is read-only, and is accessed on-demand by HTTP
120/// range requests. The directory is backed by a large anonymous memory map, and pages are marked as
121/// available to the kernel with MADV_FREE, meaning that the kernel can reclaim the memory if
122/// needed. However, in situations of low memory pressure, previously fetched index pieces will
123/// remain in memory for fast subsequent access. In practice, this means that several searches sent
124/// in quick succession as a user is typing out a query will warm the cache for the final search,
125/// making it feel faster.
126#[derive(Debug, Clone)]
127pub struct RemoteDirectory<const CHUNK_SIZE: usize> {
128    base_url: String,
129    file_handle_cache: Arc<DashMap<String, Arc<HttpFileHandle<CHUNK_SIZE>>>>,
130    atomic_read_cache: Arc<DashMap<String, Vec<u8>>>,
131    uffd_runtime: Arc<Runtime>,
132}
133
134impl<const CHUNK_SIZE: usize> RemoteDirectory<CHUNK_SIZE> {
135    /// Create a new remote directory with the given base URL. The base URL can have a trailing
136    /// slash or not, but a trailing slash will be appended if one is missing. For example, a
137    /// request for `meta.json` on a directory with the base URL `http://localhost:8080` will result
138    /// in a GET request to `http://localhost:8080/meta.json`.
139    pub fn new(base_url: &str) -> Self {
140        let rt = Runtime::new().unwrap();
141
142        Self {
143            base_url: base_url.to_string(),
144            file_handle_cache: Arc::new(DashMap::new()),
145            atomic_read_cache: Arc::new(DashMap::new()),
146            uffd_runtime: Arc::new(rt),
147        }
148    }
149
150    fn format_url(&self, path: &Path) -> String {
151        if self.base_url.ends_with('/') {
152            format!("{}{}", self.base_url, path.display())
153        } else {
154            format!("{}/{}", self.base_url, path.display())
155        }
156    }
157}
158
159impl<const CHUNK_SIZE: usize> Directory for RemoteDirectory<CHUNK_SIZE> {
160    fn get_file_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>, OpenReadError> {
161        let url = self.format_url(path);
162        {
163            if let Some(file_handle) = self.file_handle_cache.get(&url) {
164                return Ok(file_handle.clone());
165            }
166        }
167        let file_len = query_len::len(&url);
168        let len = round_up_to_page(file_len, CHUNK_SIZE);
169
170        if len == 0 {
171            return Ok(Arc::new(HttpFileHandle::<CHUNK_SIZE> {
172                owned_bytes: Arc::new(OwnedBytes::new(MmapArc { slice: &[] })),
173                _uffd_file: None,
174            }));
175        }
176
177        let file_handle = Arc::new(HttpFileHandle::<CHUNK_SIZE>::new(
178            self.uffd_runtime.clone(),
179            file_len,
180            url.clone(),
181        ));
182        self.file_handle_cache.insert(url, file_handle.clone());
183
184        Ok(file_handle)
185    }
186
187    fn delete(&self, path: &Path) -> Result<(), DeleteError> {
188        if path == Path::new(".tantivy-meta.lock") {
189            return Ok(());
190        }
191
192        Err(DeleteError::IoError {
193            io_error: Arc::new(std::io::Error::new(
194                std::io::ErrorKind::Other,
195                "Delete not supported",
196            )),
197            filepath: path.to_path_buf(),
198        })
199    }
200
201    fn exists(&self, path: &Path) -> Result<bool, OpenReadError> {
202        if path == Path::new(".tantivy-meta.lock") {
203            return Ok(true);
204        }
205        Ok(query_len::len(&self.format_url(path)) > 0)
206    }
207
208    fn open_write(&self, path: &Path) -> Result<WritePtr, OpenWriteError> {
209        if path == Path::new(".tantivy-meta.lock") {
210            return Ok(WritePtr::new(Box::new(VecWriter::new(path.to_path_buf()))));
211        }
212        dbg!(path);
213        Err(OpenWriteError::IoError {
214            io_error: Arc::new(std::io::Error::new(
215                std::io::ErrorKind::Other,
216                "Write not supported",
217            )),
218            filepath: path.to_path_buf(),
219        })
220    }
221
222    fn atomic_read(&self, path: &Path) -> Result<Vec<u8>, OpenReadError> {
223        let url = self.format_url(path);
224        if let Some(bytes) = self.atomic_read_cache.get(&url) {
225            return Ok(bytes.clone());
226        }
227
228        info!("Fetching {} in atomic read.", url);
229        let response = BLOCKING_HTTP_CLIENT.with(|client| client.get(&url).send());
230        let response = if let Err(_e) = response {
231            return Err(OpenReadError::IoError {
232                io_error: Arc::new(std::io::Error::new(
233                    std::io::ErrorKind::Other,
234                    "Fetch failed for atomic read.",
235                )),
236                filepath: path.to_path_buf(),
237            });
238        } else {
239            response.unwrap()
240        };
241        let bytes = response.bytes().unwrap();
242
243        let bytes = bytes.to_vec();
244        self.atomic_read_cache.insert(url, bytes.clone());
245        Ok(bytes)
246    }
247
248    fn atomic_write(&self, _path: &Path, _data: &[u8]) -> std::io::Result<()> {
249        Err(std::io::Error::new(
250            std::io::ErrorKind::Other,
251            "Write not supported",
252        ))
253    }
254
255    fn sync_directory(&self) -> std::io::Result<()> {
256        Ok(())
257    }
258
259    fn watch(
260        &self,
261        _watch_callback: tantivy::directory::WatchCallback,
262    ) -> tantivy::Result<tantivy::directory::WatchHandle> {
263        Ok(WatchHandle::empty())
264    }
265}
266
#[cfg(test)]
pub(crate) mod test {

    use std::{ops::Range, path::PathBuf, str::FromStr, sync::OnceLock};

    use tantivy::{directory::ManagedDirectory, doc, schema::Field, Directory, Index};
    use tiny_http::{Header, Method, Response, Server, StatusCode};

    /// Base URL (`http://<addr>`) of the in-process test server started by `ctor_init`.
    ///
    /// It is set on the constructor thread before the server thread is spawned. Tests can
    /// therefore `get().unwrap()` it without racing the server start-up.
    pub(crate) static TEST_SERVER_BASE_URL: OnceLock<String> = OnceLock::new();

    /// The `name` field of [`test_schema`].
    pub(crate) fn test_schema_name() -> Field {
        test_schema().get_field("name").unwrap()
    }

    /// The `doc` field of [`test_schema`].
    pub(crate) fn test_schema_doc() -> Field {
        test_schema().get_field("doc").unwrap()
    }

    /// Schema with two stored, full-text fields: `name` and `doc`.
    pub(crate) fn test_schema() -> tantivy::schema::Schema {
        let mut schema_builder = tantivy::schema::Schema::builder();
        schema_builder.add_text_field("name", tantivy::schema::TEXT | tantivy::schema::STORED);
        schema_builder.add_text_field("doc", tantivy::schema::TEXT | tantivy::schema::STORED);
        schema_builder.build()
    }

    /// Builds a two-document in-RAM index and returns its directory.
    ///
    /// The two documents are the MIT and Apache licence texts. Their segments are merged into
    /// one, and checksums of every managed file except `meta.json` are validated.
    fn init_test_index_no_remote() -> ManagedDirectory {
        let schema = test_schema();
        let index = Index::create_in_ram(schema);
        let index = std::thread::spawn(move || {
            let mut writer = index.writer(15_000_000).unwrap();
            writer
                .add_document(doc!(
                    test_schema_name() => "LICENSE_MIT",
                    test_schema_doc() => include_str!("../LICENSE_MIT"),
                ))
                .unwrap();
            writer
                .add_document(doc!(
                    test_schema_name() => "LICENSE_APACHE",
                    test_schema_doc() => include_str!("../LICENSE_APACHE"),
                ))
                .unwrap();
            writer.commit().unwrap();
            drop(writer);
            let ids = index.searchable_segment_ids().unwrap();
            let writer = index.writer(15_000_000).unwrap();

            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let mut writer = writer;
                writer.merge(&ids).await.unwrap()
            });

            index
        })
        .join()
        .unwrap();
        let dir = index.directory().clone();
        drop(index);

        for path in dir.list_managed_files() {
            if path.ends_with("meta.json") {
                continue;
            }
            dir.validate_checksum(&path).unwrap();
        }

        dir
    }

    /// Opens the test index through a `RemoteDirectory` pointed at the test server.
    pub(crate) fn test_index() -> Index {
        // Low chunk size to test multi-chunk reads without needing a big index.
        let http_directory =
            super::RemoteDirectory::<8192>::new(TEST_SERVER_BASE_URL.get().unwrap());
        Index::open(http_directory).unwrap()
    }

    /// Parses a `Range` header value of the form `bytes=START-END` into a half-open range.
    ///
    /// END is inclusive, or empty for "to end of file". The result is clamped to `len`.
    fn parse_range(header_value: &str, len: usize) -> Range<usize> {
        let spec = header_value.split('=').last().unwrap();
        let (start, end) = spec.split_once('-').unwrap();
        let end = if end.is_empty() {
            len
        } else {
            (1 + end.parse::<usize>().unwrap()).min(len)
        };
        let start = start.parse::<usize>().unwrap().min(end);
        start..end
    }

    /// Builds the test index and serves its files over HTTP from a background thread.
    ///
    /// - `GET` returns the whole file, or the requested slice when a `Range` header is present
    ///   (always with status 200).
    /// - `HEAD` returns an empty body with a `Content-Length` header.
    /// - Unknown files get a 404 instead of panicking the server thread.
    /// - Other methods are not explicitly answered.
    fn run_test_server() {
        let test_index = init_test_index_no_remote();

        let server = Server::http("127.0.0.1:0").unwrap();
        // Publish the URL before spawning, so tests never observe an empty `OnceLock`.
        TEST_SERVER_BASE_URL.get_or_init(|| format!("http://{}", server.server_addr()));

        std::thread::spawn(move || {
            for req in server.incoming_requests() {
                let is_get = req.method() == &Method::Get;
                if !is_get && req.method() != &Method::Head {
                    continue;
                }

                let path = PathBuf::from_str(req.url().trim_start_matches('/')).unwrap();
                let Ok(data) = test_index.atomic_read(&path) else {
                    req.respond(Response::empty(StatusCode(404))).unwrap();
                    continue;
                };

                if is_get {
                    let range_value = req
                        .headers()
                        .iter()
                        .find(|h| h.field.as_str().to_ascii_lowercase() == "range")
                        .map(|h| h.value.to_string());
                    let body = match range_value {
                        Some(value) => data[parse_range(&value, data.len())].to_vec(),
                        None => data,
                    };
                    req.respond(Response::from_data(body)).unwrap();
                } else {
                    let mut response = Response::from_string("".to_string());
                    response.add_header(
                        Header::from_bytes(&b"Content-Length"[..], data.len().to_string())
                            .unwrap(),
                    );
                    req.respond(response).unwrap();
                }
            }
        });
    }

    #[ctor::ctor]
    fn ctor_init() {
        run_test_server();
    }

    #[test]
    fn test_has_meta_json() {
        let http_directory =
            super::RemoteDirectory::<8192>::new(TEST_SERVER_BASE_URL.get().unwrap());
        assert!(
            http_directory
                .atomic_read(std::path::Path::new("meta.json"))
                .unwrap()
                .len()
                > 0
        );
    }

    #[test]
    fn test_has_docs() {
        let reader = test_index().reader().unwrap();
        assert_eq!(reader.searcher().num_docs(), 2);
    }

    #[test]
    fn search_docs() {
        let index = test_index();
        let reader = index.reader().unwrap();
        let searcher = reader.searcher();
        let query_parser = tantivy::query::QueryParser::for_index(&index, vec![test_schema_name()]);
        let query = query_parser.parse_query("LICENSE_MIT").unwrap();
        let top_docs = searcher
            .search(&query, &tantivy::collector::TopDocs::with_limit(10))
            .unwrap();
        assert_eq!(top_docs.len(), 1);
    }
}