Skip to main content

llm_multimodal/
tracker.rs

1use std::{collections::HashMap, sync::Arc};
2
3use tokio::task::JoinHandle;
4
5use super::{
6    error::{MultiModalError, MultiModalResult},
7    media::{ImageFetchConfig, MediaConnector, MediaSource},
8    types::{
9        ImageDetail, MediaContentPart, Modality, MultiModalData, MultiModalUUIDs, TrackedMedia,
10    },
11};
12
13type PendingTask = JoinHandle<MultiModalResult<TrackedMedia>>;
14
15#[derive(Debug)]
16pub struct TrackerOutput {
17    pub data: MultiModalData,
18    pub uuids: MultiModalUUIDs,
19}
20
21pub struct AsyncMultiModalTracker {
22    media_connector: Arc<MediaConnector>,
23    pending: HashMap<Modality, Vec<PendingTask>>,
24    uuids: MultiModalUUIDs,
25}
26
27impl AsyncMultiModalTracker {
28    pub fn new(media_connector: Arc<MediaConnector>) -> Self {
29        Self {
30            media_connector,
31            pending: HashMap::new(),
32            uuids: HashMap::new(),
33        }
34    }
35
36    pub fn push_part(&mut self, part: MediaContentPart) -> MultiModalResult<()> {
37        match part {
38            MediaContentPart::Text { .. } => {}
39            MediaContentPart::ImageUrl { url, detail, uuid } => {
40                let source = match url::Url::parse(&url) {
41                    Ok(parsed) if parsed.scheme() == "data" => MediaSource::DataUrl(url),
42                    _ => MediaSource::Url(url),
43                };
44                self.enqueue_image(source, detail.unwrap_or_default(), uuid);
45            }
46            MediaContentPart::ImageData {
47                data,
48                mime_type: _,
49                uuid,
50                detail,
51            } => {
52                self.enqueue_image(
53                    MediaSource::InlineBytes(data),
54                    detail.unwrap_or_default(),
55                    uuid,
56                );
57            }
58            MediaContentPart::ImageEmbeds { .. } => {
59                return Err(MultiModalError::UnsupportedContent("image_embeds"));
60            }
61        }
62        Ok(())
63    }
64
65    pub async fn finalize(mut self) -> MultiModalResult<TrackerOutput> {
66        let mut data = MultiModalData::new();
67        for (modality, tasks) in self.pending.drain() {
68            let mut items = Vec::with_capacity(tasks.len());
69            for task in tasks {
70                let media = task.await??;
71                items.push(media);
72            }
73            data.insert(modality, items);
74        }
75
76        Ok(TrackerOutput {
77            data,
78            uuids: self.uuids,
79        })
80    }
81
82    fn enqueue_image(&mut self, source: MediaSource, detail: ImageDetail, uuid: Option<String>) {
83        let modality = Modality::Image;
84        self.uuids.entry(modality).or_default().push(uuid);
85
86        let connector = Arc::clone(&self.media_connector);
87        #[expect(
88            clippy::disallowed_methods,
89            reason = "spawn handle is stored in self.pending and awaited in finalize(); fire-and-forget is intentional for concurrent media fetching"
90        )]
91        let handle = tokio::spawn(async move {
92            let frame = connector
93                .fetch_image(source, ImageFetchConfig { detail })
94                .await?;
95            Ok(TrackedMedia::Image(frame))
96        });
97
98        self.pending.entry(modality).or_default().push(handle);
99    }
100}