1use std::path::{Path, PathBuf};
7
8use graphify_security::validate_url;
9use regex::Regex;
10use reqwest::Client;
11use thiserror::Error;
12use tracing::info;
13
14#[derive(Debug, Error)]
16pub enum IngestError {
17 #[error("HTTP error: {0}")]
18 Http(#[from] reqwest::Error),
19
20 #[error("IO error: {0}")]
21 Io(#[from] std::io::Error),
22
23 #[error("security error: {0}")]
24 Security(#[from] graphify_security::SecurityError),
25
26 #[error("ingest error: {0}")]
27 Other(String),
28}
29
30pub async fn ingest_url(url: &str, output_dir: &Path) -> Result<PathBuf, IngestError> {
35 let validated = validate_url(url)?;
36 let client = Client::new();
37
38 let url_str = validated.as_str();
39 if url_str.contains("arxiv.org") {
40 ingest_arxiv(&client, url_str, output_dir).await
41 } else if url_str.contains("twitter.com") || url_str.contains("x.com") {
42 ingest_tweet(&client, url_str, output_dir).await
43 } else if url_str.ends_with(".pdf") {
44 ingest_pdf(&client, url_str, output_dir).await
45 } else {
46 ingest_webpage(&client, url_str, output_dir).await
47 }
48}
49
50async fn ingest_arxiv(client: &Client, url: &str, out: &Path) -> Result<PathBuf, IngestError> {
52 let abs_url = url.replace("/pdf/", "/abs/");
54
55 let response = client.get(&abs_url).send().await?;
56 let html = response.text().await?;
57
58 let arxiv_id = abs_url
60 .split('/')
61 .next_back()
62 .unwrap_or("unknown")
63 .trim_end_matches(".pdf");
64
65 let title = extract_between(&html, "<title>", "</title>")
67 .unwrap_or_else(|| format!("arXiv:{}", arxiv_id));
68 let title = strip_html_tags(&title).trim().to_string();
69
70 let abstract_text = extract_between(
72 &html,
73 "<blockquote class=\"abstract mathjax\">",
74 "</blockquote>",
75 )
76 .or_else(|| extract_between(&html, "Abstract:</span>", "</blockquote>"))
77 .unwrap_or_default();
78 let abstract_text = strip_html_tags(&abstract_text).trim().to_string();
79
80 let filename = format!("arxiv_{}.md", sanitize_filename(arxiv_id));
81 let path = out.join(&filename);
82 std::fs::create_dir_all(out)?;
83
84 let content = format!(
85 "---\nsource: {}\ntype: arxiv\narxiv_id: {}\ntitle: \"{}\"\n---\n\n# {}\n\n## Abstract\n\n{}\n",
86 url, arxiv_id, title, title, abstract_text
87 );
88 std::fs::write(&path, content)?;
89
90 info!("Ingested arXiv paper: {} -> {}", arxiv_id, path.display());
91 Ok(path)
92}
93
94async fn ingest_tweet(client: &Client, url: &str, out: &Path) -> Result<PathBuf, IngestError> {
96 let oembed_url = format!(
97 "https://publish.twitter.com/oembed?url={}&omit_script=true",
98 urlencoding::encode(url)
99 );
100
101 let response = client.get(&oembed_url).send().await?;
102
103 let (author, text) = if response.status().is_success() {
104 let json: serde_json::Value = response.json().await?;
105 let author = json
106 .get("author_name")
107 .and_then(|v| v.as_str())
108 .unwrap_or("unknown")
109 .to_string();
110 let html_content = json
111 .get("html")
112 .and_then(|v| v.as_str())
113 .unwrap_or("")
114 .to_string();
115 let text = strip_html_tags(&html_content);
116 (author, text)
117 } else {
118 ("unknown".to_string(), format!("Tweet from: {}", url))
119 };
120
121 let tweet_id = url
123 .split('/')
124 .next_back()
125 .unwrap_or("unknown")
126 .split('?')
127 .next()
128 .unwrap_or("unknown");
129
130 let filename = format!("tweet_{}.md", sanitize_filename(tweet_id));
131 let path = out.join(&filename);
132 std::fs::create_dir_all(out)?;
133
134 let content = format!(
135 "---\nsource: {}\ntype: tweet\nauthor: \"{}\"\ntweet_id: {}\n---\n\n{}\n",
136 url,
137 author,
138 tweet_id,
139 text.trim()
140 );
141 std::fs::write(&path, content)?;
142
143 info!("Ingested tweet: {} -> {}", tweet_id, path.display());
144 Ok(path)
145}
146
147async fn ingest_pdf(client: &Client, url: &str, out: &Path) -> Result<PathBuf, IngestError> {
149 let response = client.get(url).send().await?;
150 let bytes = response.bytes().await?;
151
152 let filename = url.split('/').next_back().unwrap_or("document.pdf");
153 let filename = if filename.ends_with(".pdf") {
154 filename.to_string()
155 } else {
156 format!("{}.pdf", filename)
157 };
158
159 let path = out.join(&filename);
160 std::fs::create_dir_all(out)?;
161 std::fs::write(&path, &bytes)?;
162
163 info!(
164 "Ingested PDF: {} ({} bytes) -> {}",
165 url,
166 bytes.len(),
167 path.display()
168 );
169 Ok(path)
170}
171
172async fn ingest_webpage(client: &Client, url: &str, out: &Path) -> Result<PathBuf, IngestError> {
174 let response = client.get(url).send().await?;
175 let html = response.text().await?;
176
177 let title = extract_between(&html, "<title>", "</title>")
179 .map(|t| strip_html_tags(&t))
180 .unwrap_or_default();
181
182 let text = strip_scripts_and_styles(&html);
184 let text = strip_html_tags(&text);
185 let text = collapse_whitespace(&text);
186
187 let filename = sanitize_filename(url);
188 let path = out.join(format!("{}.md", filename));
189 std::fs::create_dir_all(out)?;
190
191 let content = format!(
192 "---\nsource: {}\ntype: webpage\ntitle: \"{}\"\n---\n\n# {}\n\n{}\n",
193 url,
194 title.trim(),
195 title.trim(),
196 text.trim()
197 );
198 std::fs::write(&path, content)?;
199
200 info!("Ingested webpage: {} -> {}", url, path.display());
201 Ok(path)
202}
203
204pub fn save_query_result(
209 question: &str,
210 answer: &str,
211 memory_dir: &Path,
212 query_type: &str,
213 source_nodes: Option<&[String]>,
214) -> Result<PathBuf, IngestError> {
215 std::fs::create_dir_all(memory_dir)?;
216
217 let timestamp = std::time::SystemTime::now()
218 .duration_since(std::time::UNIX_EPOCH)
219 .unwrap_or_default()
220 .as_secs();
221
222 let filename = format!("{}_{}.md", query_type, timestamp);
223 let path = memory_dir.join(&filename);
224
225 let nodes_str = source_nodes.map(|n| n.join(", ")).unwrap_or_default();
226
227 let content = format!(
228 "---\ntype: {}\ntimestamp: {}\nnodes: [{}]\n---\n\n## Question\n\n{}\n\n## Answer\n\n{}\n",
229 query_type, timestamp, nodes_str, question, answer
230 );
231 std::fs::write(&path, content)?;
232
233 info!("Saved query result: {} -> {}", query_type, path.display());
234 Ok(path)
235}
236
/// Return the text between the first occurrence of `start` and the first
/// occurrence of `end` that follows it, or `None` if either marker is absent.
///
/// Matching is plain, case-sensitive substring search.
fn extract_between(haystack: &str, start: &str, end: &str) -> Option<String> {
    let (_, after_start) = haystack.split_once(start)?;
    let (inner, _) = after_start.split_once(end)?;
    Some(inner.to_owned())
}
247
248fn strip_scripts_and_styles(html: &str) -> String {
250 let re_script = Regex::new(r"(?is)<script[^>]*>.*?</script>").unwrap();
251 let re_style = Regex::new(r"(?is)<style[^>]*>.*?</style>").unwrap();
252 let result = re_script.replace_all(html, "");
253 re_style.replace_all(&result, "").to_string()
254}
255
/// Remove everything that looks like a tag: a `<` followed by at least one
/// non-`>` character (newlines included) and a closing `>`.
///
/// A bare `<>` or a `<` with no later `>` is kept as literal text. No HTML
/// entity decoding is performed.
fn strip_html_tags(html: &str) -> String {
    let mut text = String::with_capacity(html.len());
    let mut remaining = html;

    while let Some(lt) = remaining.find('<') {
        let after_lt = &remaining[lt + 1..];
        match after_lt.find('>') {
            // `<` … `>` with at least one character in between: drop the tag.
            Some(gt) if gt > 0 => {
                text.push_str(&remaining[..lt]);
                remaining = &after_lt[gt + 1..];
            }
            // `<>` is not a tag: keep the `<` and continue scanning at the `>`.
            Some(_) => {
                text.push_str(&remaining[..=lt]);
                remaining = after_lt;
            }
            // No `>` after this `<`, so nothing further can form a tag.
            None => break,
        }
    }

    text.push_str(remaining);
    text
}
261
/// Normalise whitespace in text extracted from HTML.
///
/// * Runs of spaces/tabs inside a line collapse to a single space.
/// * Trailing spaces and `\r` are removed from every line, so `\r\n` input is
///   normalised to `\n`.
/// * Lines that are empty or whitespace-only count as blank, and runs of
///   blank lines collapse to a single blank line (at most two consecutive
///   `\n`).
///
/// Unlike a pure `\n{3,}` replacement, this also collapses "blank" lines that
/// contain indentation, which is what stripped HTML usually produces.
fn collapse_whitespace(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut line = String::new();
    // Newlines seen since the last non-blank line was emitted.
    let mut pending_newlines = 0usize;

    for (idx, raw_line) in text.split('\n').enumerate() {
        if idx > 0 {
            pending_newlines += 1;
        }

        line.clear();
        let mut in_blank_run = false;
        for ch in raw_line.chars() {
            if ch == ' ' || ch == '\t' {
                if !in_blank_run {
                    line.push(' ');
                }
                in_blank_run = true;
            } else {
                line.push(ch);
                in_blank_run = false;
            }
        }

        let trimmed = line.trim_end_matches([' ', '\r']);
        if trimmed.is_empty() {
            continue;
        }

        for _ in 0..pending_newlines.min(2) {
            out.push('\n');
        }
        pending_newlines = 0;
        out.push_str(trimmed);
    }

    for _ in 0..pending_newlines.min(2) {
        out.push('\n');
    }
    out
}
269
/// Turn an arbitrary string (typically a URL) into a file-name-safe stem.
///
/// Every `https://` and then every `http://` occurrence is removed. URL
/// separators (`/ ? & = #` and space) become `_`. Anything that is not
/// alphanumeric, `_`, `-` or `.` is dropped. The result is capped at 80
/// characters, so distinct inputs sharing a long prefix map to the same name.
fn sanitize_filename(input: &str) -> String {
    const SEPARATORS: [char; 6] = ['/', '?', '&', '=', '#', ' '];

    let without_scheme = input.replace("https://", "").replace("http://", "");
    without_scheme
        .chars()
        .map(|c| if SEPARATORS.contains(&c) { '_' } else { c })
        .filter(|&c| c.is_alphanumeric() || matches!(c, '_' | '-' | '.'))
        .take(80)
        .collect()
}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_strip_html_tags() {
        let cases = [
            ("<p>Hello <b>world</b></p>", "Hello world"),
            ("no tags", "no tags"),
            ("<br/>", ""),
        ];
        for &(input, expected) in &cases {
            assert_eq!(strip_html_tags(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_strip_scripts_and_styles() {
        let cases = [
            (
                "<p>Before</p><script>alert(1)</script><p>After</p>",
                "<p>Before</p><p>After</p>",
            ),
            ("<style>.x{color:red}</style><p>Content</p>", "<p>Content</p>"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                strip_scripts_and_styles(input),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("https://example.com/page?q=1", "example.com_page_q_1"),
            ("simple", "simple"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(sanitize_filename(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sanitize_filename_max_length() {
        let sanitized = sanitize_filename(&"a".repeat(200));
        assert!(sanitized.len() <= 80);
    }

    #[test]
    fn test_extract_between() {
        let found = extract_between("<title>Hello</title>", "<title>", "</title>");
        assert_eq!(found.as_deref(), Some("Hello"));
        assert!(extract_between("no markers", "<a>", "</a>").is_none());
    }

    #[test]
    fn test_collapse_whitespace() {
        let cases = [("a b c", "a b c"), ("a\n\n\n\nb", "a\n\nb")];
        for &(input, expected) in &cases {
            assert_eq!(collapse_whitespace(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_save_query_result() {
        let dir = tempfile::tempdir().unwrap();
        let nodes = ["node1".to_string(), "node2".to_string()];
        let saved = save_query_result(
            "What is Rust?",
            "A systems programming language.",
            dir.path(),
            "query",
            Some(&nodes),
        )
        .unwrap();

        assert!(saved.exists());
        let body = std::fs::read_to_string(&saved).unwrap();
        for needle in [
            "What is Rust?",
            "systems programming language",
            "node1, node2",
            "type: query",
        ]
        .iter()
        {
            assert!(body.contains(needle), "missing {:?} in {:?}", needle, body);
        }
    }

    #[test]
    fn test_save_query_result_no_nodes() {
        let dir = tempfile::tempdir().unwrap();
        let saved = save_query_result("question", "answer", dir.path(), "chat", None).unwrap();

        let body = std::fs::read_to_string(&saved).unwrap();
        assert!(body.contains("nodes: []"));
    }
}