ragit/index/commands/
build.rs

1use super::{Index, erase_lines};
2use crate::chunk;
3use crate::constant::{CHUNK_DIR_NAME, IMAGE_DIR_NAME};
4use crate::error::Error;
5use crate::index::{
6    ChunkBuildInfo,
7    FileReader,
8    IIStatus,
9    LoadMode,
10};
11use crate::uid::Uid;
12use ragit_api::audit::AuditRecord;
13use ragit_fs::{
14    WriteMode,
15    exists,
16    parent,
17    remove_file,
18    set_extension,
19    try_create_dir,
20    write_bytes,
21};
22use sha3::{Digest, Sha3_256};
23use std::collections::hash_map::{Entry, HashMap};
24use std::time::{Duration, Instant};
25use tokio::sync::mpsc;
26
/// Outcome of a `rag build` run. Per-file failures do not abort the build,
/// so successes and errors are reported side by side.
pub struct BuildResult {
    // number of files that were successfully chunked and committed
    pub success: usize,

    /// Vec<(file, error)>
    pub errors: Vec<(String, String)>,
}
33
34impl Index {
35    pub async fn build(&mut self, workers: usize, quiet: bool) -> Result<BuildResult, Error> {
36        let mut workers = init_workers(workers, self.root_dir.clone());
37        let started_at = Instant::now();
38
39        // TODO: API is messy. I want `rag build` to be generous. When it fails to process
40        // a file, I want it to skip the file and process the other files. So, it doesn't
41        // return on error but pushes the errors to `result.errors`.
42        // Still, there're unrecoverable errors. They just kill all the workers and return immediately.
43        // That's why the code is messy with `?` and `match _ { Err(_) => {} }`
44        match self.build_worker(&mut workers, started_at, quiet).await {
45            Ok(result) => {
46                if !quiet {
47                    let elapsed_time = Instant::now().duration_since(started_at).as_secs();
48                    println!("---");
49                    println!("completed building a knowledge-base");
50                    println!("total elapsed time: {:02}:{:02}", elapsed_time / 60, elapsed_time % 60);
51                    println!(
52                        "successfully processed {} file{}",
53                        result.success,
54                        if result.success > 1 { "s" } else { "" },
55                    );
56                    println!(
57                        "{} error{}",
58                        result.errors.len(),
59                        if result.errors.len() > 1 { "s" } else { "" },
60                    );
61
62                    for (file, error) in result.errors.iter() {
63                        println!("    `{file}`: {error}");
64                    }
65                }
66
67                Ok(result)
68            },
69            Err(e) => {
70                for worker in workers.iter_mut() {
71                    let _ = worker.send(Request::Kill);
72                }
73
74                if !quiet {
75                    eprintln!("---");
76                    eprintln!("Failed to build a knowledge-base");
77                }
78
79                Err(e)
80            },
81        }
82    }
83
84    async fn build_worker(
85        &mut self,
86        workers: &mut Vec<Channel>,
87        started_at: Instant,
88        quiet: bool,
89    ) -> Result<BuildResult, Error> {
90        let mut killed_workers = vec![];
91        let mut staged_files = self.staged_files.clone();
92        let mut curr_completed_files = vec![];
93        let mut success = 0;
94        let mut errors = vec![];
95        let mut buffered_chunk_count = 0;
96        let mut flush_count = 0;
97
98        // HashMap<file, HashMap<index in file, chunk uid>>
99        let mut buffer: HashMap<String, HashMap<usize, Uid>> = HashMap::new();
100
101        // HashMap<worker id, file>
102        let mut curr_processing_file: HashMap<usize, String> = HashMap::new();
103
104        for (worker_index, worker) in workers.iter_mut().enumerate() {
105            if let Some(file) = staged_files.pop() {
106                // Previously, all the builds were in serial and this field tells
107                // which file the index is building. When something goes wrong, ragit
108                // reads this field and clean up garbages. Now, all the builds are in
109                // parallel and there's no such thing like `curr_processing_file`. But
110                // we still need to tell whether something went wrong while building
111                // and this field does that. If it's `Some(_)`, something's wrong and
112                // clean-up has to be done.
113                self.curr_processing_file = Some(String::new());
114
115                buffer.insert(file.clone(), HashMap::new());
116                curr_processing_file.insert(worker_index, file.clone());
117                worker.send(Request::BuildChunks { file }).map_err(|_| Error::MPSCError(String::from("Build worker hung up")))?;
118            }
119
120            else {
121                worker.send(Request::Kill).map_err(|_| Error::MPSCError(String::from("Build worker hung up.")))?;
122                killed_workers.push(worker_index);
123            }
124        }
125
126        self.save_to_file()?;
127        let mut has_to_erase_lines = false;
128
129        loop {
130            if !quiet {
131                self.render_build_dashboard(
132                    &buffer,
133                    &curr_completed_files,
134                    &errors,
135                    started_at.clone(),
136                    flush_count,
137                    has_to_erase_lines,
138                );
139                has_to_erase_lines = true;
140            }
141
142            for (worker_index, worker) in workers.iter_mut().enumerate() {
143                if killed_workers.contains(&worker_index) {
144                    continue;
145                }
146
147                match worker.try_recv() {
148                    Ok(msg) => match msg {
149                        Response::ChunkComplete { file, chunk_uid, index } => {
150                            buffered_chunk_count += 1;
151
152                            match buffer.entry(file.to_string()) {
153                                Entry::Occupied(mut chunks) => {
154                                    if let Some(prev_uid) = chunks.get_mut().insert(index, chunk_uid) {
155                                        return Err(Error::Internal(format!("{}th chunk of {file} is created twice: {prev_uid}, {chunk_uid}", index + 1)));
156                                    }
157                                },
158                                Entry::Vacant(e) => {
159                                    e.insert([(index, chunk_uid)].into_iter().collect());
160                                },
161                            }
162                        },
163                        Response::FileComplete { file, chunk_count } => {
164                            match buffer.get(&file) {
165                                Some(chunks) => {
166                                    if chunks.len() != chunk_count {
167                                        return Err(Error::Internal(format!("Some chunks in `{file}` are missing: expected {chunk_count} chunks, got only {} chunks.", chunks.len())));
168                                    }
169
170                                    for i in 0..chunk_count {
171                                        if !chunks.contains_key(&i) {
172                                            return Err(Error::Internal(format!(
173                                                "{} chunk of `{file}` is missing.",
174                                                match i {
175                                                    0 => String::from("1st"),
176                                                    1 => String::from("2nd"),
177                                                    2 => String::from("3rd"),
178                                                    n => format!("{}th", n + 1),
179                                                },
180                                            )));
181                                        }
182                                    }
183                                },
184                                None if chunk_count != 0 => {
185                                    return Err(Error::Internal(format!("Some chunks in `{file}` are missing: expected {chunk_count} chunks, got no chunks.")));
186                                },
187                                None => {},
188                            }
189
190                            if let Some(file) = staged_files.pop() {
191                                buffer.insert(file.clone(), HashMap::new());
192                                curr_processing_file.insert(worker_index, file.clone());
193                                worker.send(Request::BuildChunks { file }).map_err(|_| Error::MPSCError(String::from("Build worker hung up.")))?;
194                            }
195
196                            else {
197                                worker.send(Request::Kill).map_err(|_| Error::MPSCError(String::from("Build worker hung up.")))?;
198                                killed_workers.push(worker_index);
199                            }
200
201                            curr_completed_files.push(file);
202                            success += 1;
203                        },
204                        Response::Error(e) => {
205                            if let Some(file) = curr_processing_file.get(&worker_index) {
206                                errors.push((file.to_string(), format!("{e:?}")));
207
208                                // clean up garbages of the failed file
209                                let chunk_uids = buffer.get(file).unwrap().iter().map(
210                                    |(_, uid)| *uid
211                                ).collect::<Vec<_>>();
212
213                                for chunk_uid in chunk_uids.iter() {
214                                    let chunk_path = Index::get_uid_path(
215                                        &self.root_dir,
216                                        CHUNK_DIR_NAME,
217                                        *chunk_uid,
218                                        Some("chunk"),
219                                    )?;
220                                    remove_file(&chunk_path)?;
221                                    let tfidf_path = set_extension(&chunk_path, "tfidf")?;
222
223                                    if exists(&tfidf_path) {
224                                        remove_file(&tfidf_path)?;
225                                    }
226                                }
227
228                                buffered_chunk_count -= chunk_uids.len();
229                                buffer.remove(file);
230                            }
231
232                            // very small QoL hack: if there's no api key, every file will
233                            // fail with the same error. We escape before that happens
234                            if matches!(e, Error::ApiKeyNotFound { .. }) && success == 0 {
235                                return Err(e);
236                            }
237
238                            if let Some(file) = staged_files.pop() {
239                                buffer.insert(file.clone(), HashMap::new());
240                                curr_processing_file.insert(worker_index, file.clone());
241                                worker.send(Request::BuildChunks { file }).map_err(|_| Error::MPSCError(String::from("Build worker hung up.")))?;
242                            }
243
244                            else {
245                                worker.send(Request::Kill).map_err(|_| Error::MPSCError(String::from("Build worker hung up.")))?;
246                                killed_workers.push(worker_index);
247                            }
248                        },
249                    },
250                    Err(mpsc::error::TryRecvError::Empty) => {},
251                    Err(mpsc::error::TryRecvError::Disconnected) => {
252                        if !killed_workers.contains(&worker_index) {
253                            return Err(Error::MPSCError(String::from("Build worker hung up.")));
254                        }
255                    },
256                }
257            }
258
259            // It flushes and commits 20 or more files at once.
260            // TODO: this number has to be configurable
261            if curr_completed_files.len() >= 20 || killed_workers.len() == workers.len() {
262                self.staged_files = self.staged_files.iter().filter(
263                    |staged_file| !curr_completed_files.contains(staged_file)
264                ).map(
265                    |staged_file| staged_file.to_string()
266                ).collect();
267                let mut ii_buffer = HashMap::new();
268
269                for file in curr_completed_files.iter() {
270                    let real_path = Index::get_data_path(
271                        &self.root_dir,
272                        file,
273                    )?;
274
275                    if self.processed_files.contains_key(file) {
276                        self.remove_file(
277                            real_path.clone(),
278                            false,  // dry run
279                            false,  // recursive
280                            false,  // auto
281                            false,  // staged
282                            true,   // processed
283                        )?;
284                    }
285
286                    let file_uid = Uid::new_file(&self.root_dir, &real_path)?;
287                    let mut chunk_uids = buffer.get(file).unwrap().iter().map(
288                        |(index, uid)| (*index, *uid)
289                    ).collect::<Vec<_>>();
290                    chunk_uids.sort_by_key(|(index, _)| *index);
291                    let chunk_uids = chunk_uids.into_iter().map(|(_, chunk_uid)| chunk_uid).collect::<Vec<_>>();
292                    self.add_file_index(file_uid, &chunk_uids)?;
293                    self.processed_files.insert(file.to_string(), file_uid);
294
295                    match self.ii_status {
296                        IIStatus::Complete => {
297                            for chunk_uid in chunk_uids.iter() {
298                                self.update_ii_buffer(&mut ii_buffer, *chunk_uid)?;
299                            }
300                        },
301                        IIStatus::Ongoing(_)
302                        | IIStatus::Outdated => {
303                            self.ii_status = IIStatus::Outdated;
304                        },
305                        IIStatus::None => {},
306                    }
307
308                    buffer.remove(file);
309                }
310
311                if let IIStatus::Complete = self.ii_status {
312                    self.flush_ii_buffer(ii_buffer)?;
313                }
314
315                self.chunk_count += buffered_chunk_count;
316                self.reset_uid(false /* save to file */)?;
317                self.save_to_file()?;
318
319                buffered_chunk_count = 0;
320                curr_completed_files = vec![];
321                flush_count += 1;
322
323                if killed_workers.len() == workers.len() {
324                    if !quiet {
325                        self.render_build_dashboard(
326                            &buffer,
327                            &curr_completed_files,
328                            &errors,
329                            started_at.clone(),
330                            flush_count,
331                            has_to_erase_lines,
332                        );
333                    }
334
335                    break;
336                }
337            }
338
339            std::thread::sleep(Duration::from_millis(100));
340        }
341
342        self.curr_processing_file = None;
343        self.save_to_file()?;
344        self.calculate_and_save_uid()?;
345
346        // 1. If there's an error, the knowledge-base is incomplete. We should not create a summary.
347        // 2. If there's no success and no error and we already have a summary, then
348        //    `self.get_summary().is_none()` would be false, and we'll not create a summary.
349        // 3. If there's no success and no error but we don't have a summary yet, we have to create one
350        //    because a successful `rag build` must create a summary.
351        if self.build_config.summary_after_build && self.get_summary().is_none() && errors.is_empty() {
352            if !quiet {
353                println!("Creating a summary of the knowledge-base...");
354            }
355
356            self.summary(None).await?;
357        }
358
359        Ok(BuildResult {
360            success,
361            errors,
362        })
363    }
364
365    fn render_build_dashboard(
366        &self,
367        buffer: &HashMap<String, HashMap<usize, Uid>>,
368        curr_completed_files: &[String],
369        errors: &[(String, String)],
370        started_at: Instant,
371        flush_count: usize,
372        has_to_erase_lines: bool,
373    ) {
374        if has_to_erase_lines {
375            erase_lines(9);
376        }
377
378        let elapsed_time = Instant::now().duration_since(started_at).as_secs();
379        let mut curr_processing_files = vec![];
380
381        for file in buffer.keys() {
382            if !curr_completed_files.contains(file) {
383                curr_processing_files.push(format!("`{file}`"));
384            }
385        }
386
387        println!("---");
388        println!("elapsed time: {:02}:{:02}", elapsed_time / 60, elapsed_time % 60);
389        println!("staged files: {}, processed files: {}", self.staged_files.len(), self.processed_files.len());
390        println!("errors: {}", errors.len());
391        println!("committed chunks: {}", self.chunk_count);
392
393        // It messes up with `erase_lines`
394        // println!(
395        //     "currently processing files: {}",
396        //     if curr_processing_files.is_empty() {
397        //         String::from("null")
398        //     } else {
399        //         curr_processing_files.join(", ")
400        //     },
401        // );
402
403        println!(
404            "buffered files: {}, buffered chunks: {}",
405            buffer.len(),
406            buffer.values().map(|h| h.len()).sum::<usize>(),
407        );
408        println!("flush count: {flush_count}");
409        println!("model: {}", self.api_config.model);
410
411        let mut input_tokens_s = 0;
412        let mut output_tokens_s = 0;
413        let mut input_cost_s = 0;
414        let mut output_cost_s = 0;
415
416        match self.api_config.get_api_usage(&self.root_dir, "create_chunk_from") {
417            Ok(api_records) => {
418                for AuditRecord { input_tokens, output_tokens, input_cost, output_cost } in api_records.values() {
419                    input_tokens_s += *input_tokens;
420                    output_tokens_s += *output_tokens;
421                    input_cost_s += *input_cost;
422                    output_cost_s += *output_cost;
423                }
424
425                println!(
426                    "input tokens: {input_tokens_s} ({:.3}$), output tokens: {output_tokens_s} ({:.3}$)",
427                    input_cost_s as f64 / 1_000_000.0,
428                    output_cost_s as f64 / 1_000_000.0,
429                );
430            },
431            Err(_) => {
432                println!("input tokens: ??? (????$), output tokens: ??? (????$)");
433            },
434        }
435    }
436}
437
/// Worker-side routine: splits `file` into chunks, saves each chunk (and any
/// images it references) to disk, and streams progress back to the main task.
///
/// Sends one `Response::ChunkComplete` per chunk as it's written, then a final
/// `Response::FileComplete` carrying the total chunk count so the main task
/// can verify that no chunk went missing.
async fn build_chunks(
    index: &Index,
    file: String,
    prompt_hash: String,
    tx_to_main: mpsc::UnboundedSender<Response>,
) -> Result<(), Error> {
    let real_path = Index::get_data_path(
        &index.root_dir,
        &file,
    )?;
    let mut fd = FileReader::new(
        file.clone(),
        real_path.clone(),
        &index.root_dir,
        index.build_config.clone(),
    )?;
    let build_info = ChunkBuildInfo::new(
        fd.file_reader_key(),
        prompt_hash.clone(),

        // it's not a good idea to just use `api_config.model`.
        // different `api_config.model` might point to the same model,
        // but different `get_model_by_name().name` always refer to
        // different models
        index.get_model_by_name(&index.api_config.model)?.name,
    );
    // `index_in_file` is the 0-based position of the chunk within the file;
    // it doubles as the final chunk count reported in `FileComplete`.
    let mut index_in_file = 0;
    // context from the previous chunk, fed into the next chunk's generation
    let mut previous_summary = None;

    while fd.can_generate_chunk() {
        let new_chunk = fd.generate_chunk(
            &index,
            build_info.clone(),
            previous_summary.clone(),
            index_in_file,
        ).await?;
        previous_summary = Some((new_chunk.clone(), (&new_chunk).into()));
        let new_chunk_uid = new_chunk.uid;
        let new_chunk_path = Index::get_uid_path(
            &index.root_dir,
            CHUNK_DIR_NAME,
            new_chunk_uid,
            Some("chunk"),
        )?;

        // NOTE(review): this iterates `fd.images` on every chunk. If `FileReader`
        // accumulates images across chunks (rather than clearing them per chunk),
        // earlier images get re-written and re-described each iteration — confirm
        // against `FileReader`'s implementation.
        for (uid, bytes) in fd.images.iter() {
            let image_path = Index::get_uid_path(
                &index.root_dir,
                IMAGE_DIR_NAME,
                *uid,
                Some("png"),
            )?;
            let parent_path = parent(&image_path)?;

            if !exists(&parent_path) {
                try_create_dir(&parent_path)?;
            }

            write_bytes(
                &image_path,
                &bytes,
                WriteMode::Atomic,
            )?;
            index.add_image_description(*uid).await?;
        }

        chunk::save_to_file(
            &new_chunk_path,
            &new_chunk,
            index.build_config.compression_threshold,
            index.build_config.compression_level,
            &index.root_dir,
            true,  // create tfidf
        )?;
        // report the chunk as soon as it's durably on disk
        tx_to_main.send(Response::ChunkComplete {
            file: file.clone(),
            index: index_in_file,
            chunk_uid: new_chunk_uid,
        }).map_err(|_| Error::MPSCError(String::from("Failed to send response to main")))?;
        index_in_file += 1;
    }

    tx_to_main.send(Response::FileComplete {
        file,
        chunk_count: index_in_file,
    }).map_err(|_| Error::MPSCError(String::from("Failed to send response to main")))?;
    Ok(())
}
526
/// Messages sent from the main task to a build worker.
#[derive(Debug)]
enum Request {
    /// Build all chunks of `file`, streaming progress back as `Response`s.
    BuildChunks { file: String },
    /// Shut the worker down; it exits its receive loop.
    Kill,
}
532
/// Messages sent from a build worker back to the main task.
#[derive(Debug)]
enum Response {
    /// All chunks of `file` are done; `chunk_count` lets the main task
    /// verify that every `ChunkComplete` arrived.
    FileComplete { file: String, chunk_count: usize },
    /// One chunk (`index`-th of `file`) has been written to disk.
    ChunkComplete { file: String, index: usize, chunk_uid: Uid },
    /// The worker failed (to initialize, or to build the current file).
    Error(Error),
}
539
/// A bidirectional channel pair connecting the main task to one worker.
struct Channel {
    // main -> worker requests
    tx_from_main: mpsc::UnboundedSender<Request>,
    // worker -> main responses
    rx_to_main: mpsc::UnboundedReceiver<Response>,
}
544
impl Channel {
    /// Sends a request to the worker; errors if the worker has hung up.
    pub fn send(&self, msg: Request) -> Result<(), mpsc::error::SendError<Request>> {
        self.tx_from_main.send(msg)
    }

    /// Non-blocking poll of the worker's response channel.
    pub fn try_recv(&mut self) -> Result<Response, mpsc::error::TryRecvError> {
        self.rx_to_main.try_recv()
    }
}
554
555fn init_workers(n: usize, root_dir: String) -> Vec<Channel> {
556    (0..n).map(|_| init_worker(root_dir.clone())).collect()
557}
558
559fn init_worker(root_dir: String) -> Channel {
560    let (tx_to_main, rx_to_main) = mpsc::unbounded_channel();
561    let (tx_from_main, mut rx_from_main) = mpsc::unbounded_channel();
562
563    tokio::spawn(async move {
564        // Each process requires an instance of `Index`, but I found
565        // it too difficult to send the instance via mpsc channels.
566        // So I'm just instantiating new ones here.
567        // Be careful not to modify the index!
568        let index = match Index::load(
569            root_dir,
570            LoadMode::OnlyJson,
571        ) {
572            Ok(index) => index,
573            Err(e) => {
574                let _ = tx_to_main.send(Response::Error(e));
575                drop(tx_to_main);
576                return;
577            },
578        };
579        let prompt = match index.get_prompt("summarize") {
580            Ok(prompt) => prompt,
581            Err(e) => {
582                let _ = tx_to_main.send(Response::Error(e));
583                drop(tx_to_main);
584                return;
585            },
586        };
587        let mut hasher = Sha3_256::new();
588        hasher.update(prompt.as_bytes());
589        let prompt_hash = hasher.finalize();
590        let prompt_hash = format!("{prompt_hash:064x}");
591
592        while let Some(msg) = rx_from_main.recv().await {
593            match msg {
594                Request::BuildChunks { file } => match build_chunks(
595                    &index,
596                    file,
597                    prompt_hash.clone(),
598                    tx_to_main.clone(),
599                ).await {
600                    Ok(_) => {},
601                    Err(e) => {
602                        if tx_to_main.send(Response::Error(e)).is_err() {
603                            // the parent process is dead
604                            break;
605                        }
606                    },
607                },
608                Request::Kill => { break; },
609            }
610        }
611
612        drop(tx_to_main);
613        return;
614    });
615
616    Channel {
617        rx_to_main,
618        tx_from_main,
619    }
620}